Refactor 'CompiledModule' to use cls constructors (#3329)

Refactors `CompiledModule` and its subclasses to use the name constructors `create_from_class` and `create_from_instance`. This makes it easier to add additional compilation paths to our TensorFlow integration tests.

Additional changes:
- Clean up docstrings and make type info more accurate.
- Explicitly differentiate between `backend_name` and `backend_id`:
  - `backend_name` is one of `[tf, tflite, iree_vmla, iree_llvmjit, iree_vulkan]`.
  - `backend_id` uniquely identifies each instantiated backend at test time for the purpose of saving compilation artifacts.
- Use a `defaultdict` to match IREE's behavior of compiling all methods into one binary.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
index 774412d..ba966db 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/compiler/__init__.py
@@ -99,7 +99,7 @@
 
 def tf_saved_model_to_compiler_module(
     saved_model_dir: str,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> Module:
   """Converts a TensorFlow SavedModel into a MLIR module.
@@ -108,8 +108,7 @@
 
   Args:
     saved_model_dir: Directory of the saved model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -131,18 +130,17 @@
 
 def compile_tf_saved_model(
     saved_model_dir: str,
-    exported_names: Collection[str] = (),
-    target_backends: Collection[str] = (),
+    exported_names: Sequence[str] = (),
+    target_backends: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
   """Compiles a TensorFlow SavedModel to IREE in one shot.
 
   Args:
     saved_model_dir: Directory of the saved model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -160,7 +158,7 @@
 def tf_signature_def_saved_model_to_compiler_module(
     saved_model_dir: str,
     saved_model_tags: Set[str] = set(),
-    exported_names: Collection[str] = [],
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> Module:
   """Converts a TensorFlow SignatureDef SavedModel into a MLIR module.
@@ -168,8 +166,7 @@
   Args:
     saved_model_dir: Directory of the saved model.
     saved_model_tags: Optional set of tags to use when loading the model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -194,8 +191,8 @@
 def compile_tf_signature_def_saved_model(
     saved_model_dir: str,
     saved_model_tags: Set[str] = set(),
-    exported_names: Collection[str] = (),
-    target_backends: Collection[str] = (),
+    exported_names: Sequence[str] = (),
+    target_backends: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None) -> binding.OpaqueBlob:
   """Compiles a TensorFlow SignatureDef SavedModel to IREE in one shot.
@@ -203,10 +200,9 @@
   Args:
     saved_model_dir: Directory of the saved model.
     saved_model_tags: Optional set of tags to use when loading the model.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -222,7 +218,7 @@
 
 def tf_module_to_compiler_module(
     module: tf.Module,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
     compiler_context: Optional[Context] = None,
     saved_model_dir: str = None) -> Module:
@@ -230,8 +226,7 @@
 
   Args:
     module: The tf.Module instance to convert to MLIR
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
@@ -259,8 +254,8 @@
 
 
 def compile_tf_module(module: tf.Module,
-                      exported_names: Collection[str] = (),
-                      target_backends: Collection[str] = (),
+                      exported_names: Sequence[str] = (),
+                      target_backends: Sequence[str] = (),
                       pass_pipeline: Sequence[str] = TF_IMPORT_PASS_PIPELINE,
                       compiler_context: Optional[Context] = None,
                       saved_model_dir: str = None):
@@ -268,10 +263,9 @@
 
   Args:
     module: The tf.Module instance to convert to MLIR
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    exported_names: Optional sequence representing the exported names to keep.
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled in targets).
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to TF_IMPORT_PASS_PIPELINE.
     compiler_context: The pyiree.compiler.Context() backing the module.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index fa3df03..aa9481e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -66,22 +66,22 @@
 
 
 def _parse_target_backends() -> Tuple[Sequence[str], Sequence[str]]:
-  """Decodes --target_backends and creates unique names for their artifacts."""
+  """Decodes --target_backends and creates unique ids for them."""
   backend_names = FLAGS.target_backends.split(",")
   backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
-  artifact_names = []
+  backend_ids = []
 
   # If there are multiple copies of the same backend_name, index them. e.g.
   # backend_names = ["tf", "iree_vmla", "tf"]
-  # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+  # --> backend_ids = ["tf_0", "iree_vmla", "tf_1"]
   for backend_name in backend_names:
     if backend_name in backend_to_index:
-      artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+      backend_ids.append(f"{backend_name}_{backend_to_index[backend_name]}")
       backend_to_index[backend_name] += 1
     else:
-      artifact_names.append(backend_name)
+      backend_ids.append(backend_name)
 
-  return backend_names, artifact_names
+  return backend_names, backend_ids
 
 
 def get_target_backends() -> Sequence[tf_utils.BackendInfo]:
@@ -95,10 +95,10 @@
   """
   if FLAGS.target_backends is not None:
     logging.info("Using backends from command line: %s", FLAGS.target_backends)
-    backend_names, names = _parse_target_backends()
+    backend_names, backend_ids = _parse_target_backends()
     backends = [
-        tf_utils.BackendInfo(backend, name)
-        for backend, name in zip(backend_names, names)
+        tf_utils.BackendInfo(backend_name, backend_id)
+        for backend_name, backend_id in zip(backend_names, backend_ids)
     ]
   else:
     # If no backends are specified, use them all.
@@ -261,10 +261,11 @@
       # Extract metadata from module and function.
       self.module_name = module.module_name
       self.compiled_paths = module.compiled_paths
-      self.backend_name = module.backend
+      self.backend_name = module.backend_info.backend_name
+      self.backend_id = module.backend_info.backend_id
+      self.backend_driver = module.backend_info.driver
       self.iree_serializable = module.iree_serializable()
       self.tflite_serializable = module.tflite_serializable()
-      self.backend_driver = module.backend_driver
       self.function_name = function.__name__
       self.function_sourcefile = inspect.getsourcefile(function)
       source, start_line = inspect.getsourcelines(function)
@@ -276,9 +277,10 @@
       self.module_name = _load_dict["module_name"]
       self.compiled_paths = _load_dict["compiled_paths"]
       self.backend_name = _load_dict["backend_name"]
+      self.backend_id = _load_dict["backend_id"]
+      self.backend_driver = _load_dict["backend_driver"]
       self.iree_serializable = _load_dict["iree_serializable"]
       self.tflite_serializable = _load_dict["tflite_serializable"]
-      self.backend_driver = _load_dict["backend_driver"]
       self.function_name = _load_dict["function_name"]
       self.function_sourcefile = _load_dict["function_sourcefile"]
       self.function_line_numbers = _load_dict["function_line_numbers"]
@@ -286,7 +288,7 @@
       self.calls = _load_dict["calls"]
 
   def __str__(self):
-    header = (f"Trace of {self.module_name} compiled to '{self.backend_name}' "
+    header = (f"Trace of {self.module_name} compiled to '{self.backend_id}' "
               f"on function '{self.function_name}':")
     # Give each call a number so it's easier to compare between multiple traces.
     calls = [f"{i + 1}. {str(call)}" for i, call in enumerate(self.calls)]
@@ -327,11 +329,11 @@
 
       if not calls_match:
         logging.error("Comparision between '%s' and '%s' failed on method '%s'",
-                      ref_trace.backend_name, tar_trace.backend_name,
+                      ref_trace.backend_id, tar_trace.backend_id,
                       ref_call.method)
-        logging.error("Reference call '%s':\n%s", ref_trace.backend_name,
+        logging.error("Reference call '%s':\n%s", ref_trace.backend_id,
                       ref_call)
-        logging.error("Target call '%s':\n%s", tar_trace.backend_name, tar_call)
+        logging.error("Target call '%s':\n%s", tar_trace.backend_id, tar_call)
 
       traces_match = traces_match and calls_match
     return traces_match
@@ -434,14 +436,20 @@
       trace_dir: str, path to the directory to serialize the trace to.
     """
 
+    compiled_paths = None
+    if self.compiled_paths is not None:
+      # Convert to a dict to avoid the issues with serializing defaultdicts.
+      compiled_paths = dict(self.compiled_paths)
+
     # Python serialization.
     metadata = {
         "module_name": self.module_name,
-        "compiled_paths": self.compiled_paths,
+        "compiled_paths": compiled_paths,
         "backend_name": self.backend_name,
+        "backend_id": self.backend_id,
+        "backend_driver": self.backend_driver,
         "iree_serializable": self.iree_serializable,
         "tflite_serializable": self.tflite_serializable,
-        "backend_driver": self.backend_driver,
         "function_name": self.function_name,
         "function_sourcefile": self.function_sourcefile,
         "function_line_numbers": self.function_line_numbers,
@@ -493,7 +501,7 @@
 
 
 def _get_trace_dir(artifacts_dir: str, trace: Trace) -> str:
-  trace_dir = os.path.join(artifacts_dir, trace.backend_name, "traces",
+  trace_dir = os.path.join(artifacts_dir, trace.backend_id, "traces",
                            trace.function_name)
   os.makedirs(trace_dir, exist_ok=True)
   return trace_dir
@@ -559,17 +567,17 @@
 def compile_tf_module(
     module_class: Type[tf.Module], exported_names: Sequence[str] = ()
 ) -> Callable[[Any], Any]:
-  """CompiledModuleTestCase decorator that compiles a tf.Module.
-
-  A CompiledModule is created for each backend in --target_backends. They can
-  be accessed individually via self.compiled_modules.backend_name or as a union
-  via self.get_module().
+  """Compiles module_class to each backend that we test.
 
   Args:
     module_class: the tf.Module subclass to compile.
     exported_names: optional iterable of strings representing which of
       module_class's functions to compile. If exported_names is empty all
       functions will be compiled.
+
+  Returns:
+    A 'Modules' namedtuple containing the reference module, target modules and
+    artifacts directory.
   """
 
   # Setup the directory for saving compilation artifacts and traces.
@@ -580,7 +588,7 @@
                                           f"{FLAGS.reference_backend}_ref")
   tar_backend_infos = get_target_backends()
 
-  compile_backend = lambda backend_info: backend_info.compile(
+  compile_backend = lambda backend_info: backend_info.compile_from_class(
       module_class, exported_names, artifacts_dir)
 
   ref_module = compile_backend(ref_backend_info)
@@ -631,7 +639,7 @@
     failed_backend_indices = []
     for i, tar_trace in enumerate(tar_traces):
       logging.info("Comparing the reference backend '%s' with '%s'",
-                   ref_trace.backend_name, tar_trace.backend_name)
+                   ref_trace.backend_id, tar_trace.backend_id)
       traces_match = Trace.compare_traces(ref_trace, tar_trace)
       if not traces_match:
         failed_backend_indices.append(i)
@@ -649,7 +657,7 @@
     if failed_backend_indices:
       # Extract info for logging.
       failed_backends = [
-          tar_traces[i].backend_name for i in failed_backend_indices
+          tar_traces[i].backend_id for i in failed_backend_indices
       ]
       self.fail(
           "Comparision between the reference backend and the following targets "
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index 6157c68..e1fb6e3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -117,8 +117,8 @@
       # Only outputs
       module.get_count()
 
-    module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                       tf_utils.BackendInfo('tf'))
+    module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
@@ -140,13 +140,13 @@
       module.increment()
       module.decrement()
 
-    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo('tf'))
+    tf_module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
-    vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                              tf_utils.BackendInfo('iree_vmla'))
+    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
@@ -161,13 +161,13 @@
     def vmla_function(module):
       module.increment_by(np.array([22.], dtype=np.float32))
 
-    tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo('tf'))
+    tf_module = tf_utils.TfCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
-    vmla_module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                              tf_utils.BackendInfo('iree_vmla'))
+    vmla_module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
@@ -178,12 +178,12 @@
     def trace_function(module):
       module.increment()
       module.increment_by(np.array([81.], dtype=np.float32))
-      module.increment_by_max(
-          np.array([81], dtype=np.float32), np.array([92], dtype=np.float32))
+      module.increment_by_max(np.array([81], dtype=np.float32),
+                              np.array([92], dtype=np.float32))
       module.get_count()
 
-    module = tf_utils.IreeCompiledModule(StatefulCountingModule,
-                                         tf_utils.BackendInfo('iree_vmla'))
+    module = tf_utils.IreeCompiledModule.create_from_class(
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index fc8bab6..b3237f9 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,6 +16,7 @@
 
 # pylint: disable=protected-access
 
+import collections
 import os
 import random
 import re
@@ -91,7 +92,7 @@
 def _setup_mlir_crash_reproducer(
     function: Callable[[Any], Any],
     artifacts_dir: str,
-    backend_name: str,
+    backend_id: str,
 ) -> Callable[[Any], Any]:
   """Wraps `function` so that it a MLIR crash reproducer is saved if it crashes.
 
@@ -100,7 +101,7 @@
   Args:
     function: The callable to decorate.
     artifacts_dir: The directory to write the reproducer to.
-    backend_name: The name of the backend `function` compiles to.
+    backend_id: The unique backend name to use when writting the reproducer.
 
   Returns:
     A function with the same API as the passed function.
@@ -110,7 +111,7 @@
     # Set up a crash reproducer for debugging.
     if artifacts_dir is not None:
       compiler.Context.default_crash_reproducer_path = os.path.join(
-          artifacts_dir, f"reproducer__{backend_name}.mlir")
+          artifacts_dir, f"reproducer__{backend_id}.mlir")
     try:
       results = function(*args, **kwargs)
     except Exception:  # pylint: disable=broad-except
@@ -135,7 +136,7 @@
       MLIR for the module in TF's input dialect.
     iree_input.mlir:
       The MLIR above translated to IREE via compiler.TF_IMPORT_PASS_PIPELINE.
-    backend_name/compiled.vmfb:
+    backend_id/compiled.vmfb:
       A VM FlatBuffer compiled to the target backend from the IREE MLIR above.
 
   Args:
@@ -167,7 +168,7 @@
 
   compiled_path = None
   if artifacts_dir is not None:
-    backend_dir = os.path.join(artifacts_dir, backend_info.name)
+    backend_dir = os.path.join(artifacts_dir, backend_info.backend_id)
     os.makedirs(backend_dir, exist_ok=True)
     compiled_path = os.path.join(backend_dir, "compiled.vmfb")
     logging.info("Saving compiled IREE module to: %s", compiled_path)
@@ -180,7 +181,7 @@
     module: Type[tf.Module],
     backend_info: "BackendInfo",
     exported_names: Sequence[str] = (),
-    artifacts_dir: str = None
+    artifacts_dir: str = None,
 ) -> Tuple[compiler.binding.OpaqueBlob, Union[str, None]]:
   """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
 
@@ -193,8 +194,7 @@
   Args:
     module: A tf.Module.
     backend_info: BackendInfo with the details for compiling module to IREE.
-    exported_names: Iterable of dotted function names to consider for
-      compilation.
+    exported_names: Optional sequence representing the exported names to keep.
     artifacts_dir: An optional string pointing to where compilation artifacts
       should be saved. No compilation artifacts will be saved if this is not
       provided.
@@ -212,29 +212,72 @@
                                                 artifacts_dir)
 
   _compile_module = _setup_mlir_crash_reproducer(_compile_module, artifacts_dir,
-                                                 backend_info.name)
+                                                 backend_info.backend_id)
   return _compile_module(module, exported_names, backend_info, artifacts_dir)
 
 
 class CompiledModule(object):
   """Base class for the TF and IREE compiled modules."""
 
-  def __init__(self, module_class: Type[tf.Module], backend_info: "BackendInfo",
-               exported_names: Sequence[str], artifacts_dir: str):
-    """Shared base constructor – not useful on its own."""
-    self._module_class = module_class
-    self._backend_info = backend_info
-    self._exported_names = exported_names
-    self._artifacts_dir = artifacts_dir
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+  ):
+    """Shared base constructor – not useful on its own.
 
-    # Public attributes:
-    self.backend = self._backend_info.name
-    self.backend_driver = self._backend_info.driver
-    self.module_name = self._module_class.__name__
-    self.compiled_paths = None
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+    """
+    self.module_name = module_name
+    self.backend_info = backend_info
+    self.compiled_paths = compiled_paths
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    raise NotImplementedError()
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    This is only implemented for IreeCompiledModule.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     raise NotImplementedError()
 
   def iree_serializable(self):
@@ -244,14 +287,6 @@
     return False
 
 
-def _get_non_inhereted_function_names(cls):
-  """Gets all methods that cls has that its parents don't have."""
-  names = set(dir(cls))
-  for parent in cls.__bases__:
-    names -= set(dir(parent))
-  return list(names)
-
-
 class _FunctionWrapper(object):
 
   def get_serialized_values(self) -> Tuple[Tuple[str], Tuple[str]]:
@@ -277,57 +312,94 @@
 class IreeCompiledModule(CompiledModule):
   """Iree compiled module."""
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    """Compile a tf.Module to the target backend in backend_info.
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      vm_module: rt.VmModule,
+      config: rt.Config,
+  ):
+    """Base constructor – Use one of the named constructors instead.
 
     Args:
-      module_class: the tf.Module subclass to compile.
-      backend_info: an element of BackendInfo corresponding to the IREE backend
-        to compile to.
-      exported_names: an optional iterable of strings representing which of the
-        module_class's functions to compile. If exported_names is empty all
-        functions will be compiled.
-      artifacts_dir: an optional path to save compilation artifacts to.
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      vm_module: A rt.VmModule containing compilation info to wrap.
+      config: A rt.Config containing compilation info to wrap.
     """
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._vm_module = vm_module
+    self._config = config
+    self.reinitialize()
 
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     set_random_seed()
-    self._module_blob, compiled_path = _incrementally_compile_tf_module(
-        module=module_class(),
+    module_instance = module_class()
+    return cls.create_from_instance(module_instance, backend_info,
+                                    exported_names, artifacts_dir)
+
+  @classmethod
+  def create_from_instance(cls,
+                           module_instance: tf.Module,
+                           backend_info: "BackendInfo",
+                           exported_names: Sequence[str] = (),
+                           artifacts_dir: str = None):
+    """Compile a tf.Module instance to the target backend in backend_info.
+
+    Args:
+      module_instance: The tf.Module instance to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_blob, compiled_path = _incrementally_compile_tf_module(
+        module=module_instance,
         backend_info=backend_info,
         exported_names=exported_names,
         artifacts_dir=artifacts_dir)
-    self._module = rt.VmModule.from_flatbuffer(self._module_blob)
-    self._config = rt.Config(driver_name=backend_info.driver)
+    vm_module = rt.VmModule.from_flatbuffer(module_blob)
+    config = rt.Config(driver_name=backend_info.driver)
 
-    self.compiled_paths = None
+    compiled_paths = None
     if compiled_path is not None:
-      if not len(exported_names):
-        # Get all method names on 'module_class' that aren't on 'tf.Module'.
-        # This doesn't address all possbile scenarios.
-        # TODO(meadowlark): Figure out how to get a list of all of the functions
-        # that this module has access to via `pyiree.rt.system_api.BoundModule`.
-        exported_names = _get_non_inhereted_function_names(module_class)
-      self.compiled_paths = dict([
-          (method, compiled_path) for method in exported_names
-      ])
+      # IREE bundles every compiled method into the same compiled module.
+      compiled_paths = collections.defaultdict(lambda: compiled_path)
 
-    self.reinitialize()
+    module_name = type(module_instance).__name__
+
+    return cls(module_name, backend_info, compiled_paths, vm_module, config)
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     # set_random_seed is not needed here because the model_class.__init__ is not
     # called.
-    self._context = rt.SystemContext(modules=[self._module],
+    self._context = rt.SystemContext(modules=[self._vm_module],
                                      config=self._config)
 
   def __getattr__(self, attr: str) -> _IreeFunctionWrapper:
     # Try to resolve it as a function.
-    m = self._context.modules[self._module.name]
+    m = self._context.modules[self._vm_module.name]
     f = m[attr]
     return _IreeFunctionWrapper(self._context, f)
 
@@ -376,29 +448,54 @@
   normalize TensorFlow's output to Numpy.
   """
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    """Wrap a tf.Module in a TFCompiledModule facade.
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      constructor: Callable[[], tf.Module],
+      exported_names: Sequence[str],
+  ):
+    """Base constructor – Use one of the named constructors instead.
 
     Args:
-      module_class: the tf.Module subclass to 'compile'.
-      backend_info: one of the 'tf*' elements in BackendInfo.
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      constructor: A callable (class or function) which returns the tf.Module
+        subclass instance to wrap.
       exported_names: an optional iterable of strings representing which of the
-        module_class's functions should be callable. If exported_names is empty
-        then all functions will be callable.
-      artifacts_dir: an optional path to save compilation artifacts to. Has no
-        effect for this subclass as nothing is compiled.
+        tf.Module subclass instance's functions should be callable. If
+        exported_names is empty then all functions will be callable.
     """
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+    super().__init__(module_name, backend_info, compiled_paths=None)
+    self._constructor = constructor
+    self._exported_names = exported_names
     self.reinitialize()
 
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
+    module_name = module_class.__name__
+    constructor = module_class
+    return cls(module_name, backend_info, constructor, exported_names)
+
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     set_random_seed()
-    self._tf_module = self._module_class()
+    self._tf_module = self._constructor()
 
   def __getattr__(self, attr: str) -> _TfFunctionWrapper:
     # Try to resolve it as a function.
@@ -412,6 +509,14 @@
     return _TfFunctionWrapper(f)
 
 
+def _get_non_inhereted_function_names(cls):
+  """Gets all methods that cls has that its parents don't have."""
+  names = set(dir(cls))
+  for parent in cls.__bases__:
+    names -= set(dir(parent))
+  return list(names)
+
+
 def _get_concrete_functions(module_class: Type[tf.Module],
                             exported_names: Sequence[str] = ()):
   """Get concrete functions from non-inherited methods or exported_names."""
@@ -425,12 +530,12 @@
   return functions, exported_names
 
 
-def compile_to_tflite(
+def tf_module_to_tflite_interpreters(
     module_class: Type[tf.Module],
     exported_names: Sequence[str] = (),
     artifacts_dir: str = None
 ) -> Tuple[Dict[str, tf.lite.Interpreter], Union[Dict[str, str]], None]:
-  """Compile a dict of TFLite interpreters for the methods on module_class.
+  """Compile a tf.Module to TFLite interpreters for each of its methods.
 
   Args:
     module_class: A tf.Module subclass to compile with TFLite. If module_class
@@ -463,22 +568,14 @@
     if artifacts_dir is not None:
       compiled_paths[name] = tflite_path
 
+  # Convert module_class's methods into TFLite module byte-strings.
   tflite_modules = []
-  names = []
-  if hasattr(module_class, "get_legacy_tflite_saved_model_converter_kwargs"):
-    kwargs = module_class.get_legacy_tflite_saved_model_converter_kwargs()
-    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
-        kwargs["model_path"],
-        input_arrays=kwargs["input_arrays"],
-        output_arrays=kwargs["output_arrays"])
+  functions, names = _get_concrete_functions(module_class, exported_names)
+  for function in functions:
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
     tflite_modules.append(converter.convert())
-    names.append(kwargs["exported_name"])
-  else:
-    functions, names = _get_concrete_functions(module_class, exported_names)
-    for function in functions:
-      converter = tf.lite.TFLiteConverter.from_concrete_functions([function])
-      tflite_modules.append(converter.convert())
 
+  # Load each of the converted methods above into tf.lite.Interpreters.
   for name, tflite_module in zip(names, tflite_modules):
     if artifacts_dir is None:
       with tempfile.TemporaryDirectory() as base_dir:
@@ -537,18 +634,50 @@
 class TfLiteCompiledModule(CompiledModule):
   """Compiles a tf.Module with TFLite and allows it to be called."""
 
-  def __init__(self,
-               module_class: Type[tf.Module],
-               backend_info: "BackendInfo",
-               exported_names: Sequence[str] = (),
-               artifacts_dir: str = None):
-    super().__init__(module_class, backend_info, exported_names, artifacts_dir)
+  def __init__(
+      self,
+      module_name: str,
+      backend_info: "BackendInfo",
+      compiled_paths: Dict[str, str],
+      interpreters: Dict[str, tf.lite.Interpreter],
+  ):
+    """Base constructor – Use one of the named constructors instead.
+
+    Args:
+      module_name: A name for this compiled module. In most cases this will be
+        the name of the tf.Module subclass or instance that is compiled.
+      backend_info: BackendInfo with the details about compiling this module.
+      compiled_paths: A dictionary mapping compiled method names to file paths
+        corresponding to their serialized representations.
+      interpreters: A dict of tf.lite.Interpreters to make callable.
+    """
+    super().__init__(module_name, backend_info, compiled_paths)
+    self._interpreters = interpreters
+
+  @classmethod
+  def create_from_class(cls,
+                        module_class: Type[tf.Module],
+                        backend_info: "BackendInfo",
+                        exported_names: Sequence[str] = (),
+                        artifacts_dir: str = None):
+    """Compile a tf.Module subclass to the target backend in backend_info.
+
+    Args:
+      module_class: The tf.Module subclass to compile.
+      backend_info: BackendInfo with the details for compiling module to IREE.
+      exported_names: Optional sequence representing the exported names to keep.
+      artifacts_dir: An optional string pointing to where compilation artifacts
+        should be saved. No compilation artifacts will be saved if this is not
+        provided.
+    """
     set_random_seed()
-    self._interpreters, self.compiled_paths = compile_to_tflite(
+    interpreters, compiled_paths = tf_module_to_tflite_interpreters(
         module_class, exported_names, artifacts_dir)
+    module_name = module_class.__name__
+    return cls(module_name, backend_info, compiled_paths, interpreters)
 
   def reinitialize(self):
-    """Reinitializes to the initial state of the passed module_class."""
+    """Reinitializes all stateful variables."""
     # This is a noop because TFLite (mostly) doesn't support stateful modules.
     pass
 
@@ -594,36 +723,43 @@
       },
   }
 
-  def __init__(self, backend_name: str, artifact_name: str = None):
+  def __init__(self, backend_name: str, backend_id: str = None):
     """Creates a BackendInfo with the compilation details for backend_name.
 
     Args:
       backend_name: a str specifying which backend to use. Should be one of
         'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
-      artifact_name: an optional str specifying what name to use when saving
-        compiled artifacts.
+      backend_id: an optional str specifying what name to use when saving
+        compiled artifacts. Must satisfy `backend_id.startswith(backend_name)`.
 
     Raises:
       KeyError: if backend_name is not one of ['tf', 'iree_vmla',
       'iree_llvmjit', 'iree_vulkan'].
+      ValueError: if backend_id doesn't start with backend_name.
     """
     if backend_name not in self._name_to_info:
       raise KeyError(
           "Expected backend_name to be one of "
           f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+    if backend_id is not None and not backend_id.startswith(backend_name):
+      raise ValueError(f"Expected backend_id to start with '{backend_name}' "
+                       f"but got '{backend_id}'.")
+
+    self.backend_name = backend_name
+    self.backend_id = backend_name if backend_id is None else backend_id
+
     info = self._name_to_info[backend_name]
     self._compiled_module_class = info["compiled_module_class"]
     self.driver = info["driver"]
     self.compiler_targets = info["compiler_targets"]
-    self.name = backend_name if artifact_name is None else artifact_name
 
-  def compile(self,
-              module: Type[tf.Module],
-              exported_names: Sequence[str] = (),
-              artifacts_dir: str = None) -> CompiledModule:
+  def compile_from_class(self,
+                         module_class: Type[tf.Module],
+                         exported_names: Sequence[str] = (),
+                         artifacts_dir: str = None) -> CompiledModule:
     """Creates a 'CompiledModule' for this backend."""
-    return self._compiled_module_class(module, self, exported_names,
-                                       artifacts_dir)
+    return self._compiled_module_class.create_from_class(
+        module_class, self, exported_names, artifacts_dir)
 
   @classmethod
   def get_all_backends(cls) -> Sequence["BackendInfo"]:
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index 1a441d0..a9c259c 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -89,7 +89,7 @@
   ])
   def test_unaltered_state(self, backend_name):
     backend_info = tf_utils.BackendInfo(backend_name)
-    module = backend_info.compile(StatefulCountingModule)
+    module = backend_info.compile_from_class(StatefulCountingModule)
 
     # Test that incrementing works properly.
     self.assertEqual([0.], module.get_count())
@@ -126,8 +126,8 @@
     backend_info = tf_utils.BackendInfo(backend_name)
 
     # Test compilation is the same.
-    module_1 = backend_info.compile(RandomInitModule)
-    module_2 = backend_info.compile(RandomInitModule)
+    module_1 = backend_info.compile_from_class(RandomInitModule)
+    module_2 = backend_info.compile_from_class(RandomInitModule)
     self.assertAllEqual(module_1.get(), module_2.get())
 
     # Test reinitialization is the same.
diff --git a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
index 78baeb6..df9b317 100644
--- a/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
+++ b/integrations/tensorflow/bindings/python/pyiree/xla/compiler/__init__.py
@@ -54,7 +54,7 @@
 def xla_load_module_proto(
     xla_computation,
     compiler_context: Optional[Context] = None,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE) -> Module:
   """Loads a XLA saved model from its persistent representation.
 
@@ -63,8 +63,7 @@
   Args:
     xla_computation: XLA Computation generate from XLA Python client
     compiler_context: The pyiree.compiler.Context() backing the module.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to XLA_IMPORT_PASS_PIPELINE.
 
@@ -85,21 +84,20 @@
 def xla_compile_module_proto(
     xla_computation,
     compiler_context: Optional[Context] = None,
-    exported_names: Collection[str] = (),
+    exported_names: Sequence[str] = (),
     pass_pipeline: Sequence[str] = XLA_IMPORT_PASS_PIPELINE,
-    target_backends: Collection[str] = ()
+    target_backends: Sequence[str] = ()
 ) -> binding.OpaqueBlob:
   """Loads and compiles a XLA saved model in one shot.
 
   Args:
     xla_computation: XLA Computation generate from XLA Python client
     compiler_context: The pyiree.compiler.Context() backing the module.
-    exported_names: Optional tuple of strings representing the exported names to
-      keep.
+    exported_names: Optional sequence representing the exported names to keep.
     pass_pipeline: Passes to run on the imported module prior to returning.
       Defaults to XLA_IMPORT_PASS_PIPELINE.
-    target_backends: The specific target backends to compile for (defaults to
-      all compiled in targets).
+    target_backends: Optional sequence of specific target backends to compile
+      for (defaults to all compiled in targets).
 
   Returns:
     An OpaqueBlob representing the compiled module.
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 2a0b458..7a07478 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -218,7 +218,6 @@
     name = "mobile_bert_squad_tests",
     backends_to_srcs = {
         "tf": ["mobile_bert_squad_test.py"],
-        "tflite": ["mobile_bert_squad_test.py"],
     },
     reference_backend = "tf",
     tags = [
@@ -234,6 +233,7 @@
 iree_e2e_test_suite(
     name = "mobile_bert_squad_tests_failing",
     backends_to_srcs = {
+        "tflite": ["mobile_bert_squad_test.py"],
         "iree_vmla": ["mobile_bert_squad_test.py"],
         "iree_llvmjit": ["mobile_bert_squad_test.py"],
         "iree_vulkan": ["mobile_bert_squad_test.py"],