Refactor BackendInfo, remove tf_also, fully support duplicate backends (#2738)

`BackendInfo` is refactored from a wrapper around a set of `NamedTuple`s into it's own class. This allows for the name attribute to be overridden, which is necessary for supporting saving artifacts for duplicate backends (e.g. two copies of `"iree_vmla"`). It also provides a cleaner API for compiling modules and getting all `BackendInfo` configurations.

Changes this allows:

- Compilation artifacts and traces from duplicate target backends will not overwrite each other. Instead artifacts for each backend are saved with an index (e.g. `trace__iree_vmla_0.txt`, `trace__iree_vmla_1.txt` and so on).
- The reference backend will have `_ref` appended to it's name (e.g. `trace__tf_ref.txt` or `trace__iree_vmla_ref.txt`).
- `tf_also` is removed.


Example artifacts:

```
SimpleArithmeticModule
├── compiled__iree_llvmjit.vmfb
├── compiled__iree_vmla.vmfb
├── compiled__iree_vulkan.vmfb
├── iree_input.mlir
├── tf_input.mlir
└── traces
    ├── simple_matmul__iree_llvmjit.txt
    ├── simple_matmul__iree_vmla.txt
    ├── simple_matmul__iree_vulkan.txt
    ├── simple_matmul__tf_ref.txt
    ├── simple_matmul__tf.txt
    ├── simple_mul__iree_llvmjit.txt
    ├── simple_mul__iree_vmla.txt
    ├── simple_mul__iree_vulkan.txt
    ├── simple_mul__tf_ref.txt
    └── simple_mul__tf.txt
```
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 2a4de30..b3235b3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -67,17 +67,23 @@
   return artifacts_dir
 
 
-def _parse_target_backends(target_backends):
-  """Decodes a comma-delimited string of backends into BackendInfo objects."""
-  backends = []
-  for backend_name in target_backends.split(","):
-    if backend_name not in tf_utils.BackendInfo.ALL.keys():
-      raise ValueError(
-          "Invalid backend specification string '{}', unexpected name '{}';"
-          " valid names are '{}'".format(target_backends, backend_name,
-                                         tf_utils.BackendInfo.ALL.keys()))
-    backends.append(tf_utils.BackendInfo.ALL[backend_name])
-  return backends
+def _parse_target_backends():
+  """Decodes --target_backends and creates unique names for their artifacts."""
+  backend_names = FLAGS.target_backends.split(",")
+  backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
+  artifact_names = []
+
+  # If there are multiple copies of the same backend_name, index them. e.g.
+  # backend_names = ["tf", "iree_vmla", "tf"]
+  # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+  for backend_name in backend_names:
+    if backend_name in backend_to_index:
+      artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+      backend_to_index[backend_name] += 1
+    else:
+      artifact_names.append(backend_name)
+
+  return backend_names, artifact_names
 
 
 def get_target_backends():
@@ -91,10 +97,14 @@
   """
   if FLAGS.target_backends is not None:
     logging.info("Using backends from command line: %s", FLAGS.target_backends)
-    backends = _parse_target_backends(FLAGS.target_backends)
+    backend_names, names = _parse_target_backends()
+    backends = [
+        tf_utils.BackendInfo(backend, name)
+        for backend, name in zip(backend_names, names)
+    ]
   else:
     # If no backends are specified, use them all.
-    backends = list(tf_utils.BackendInfo.ALL.values())
+    backends = tf_utils.BackendInfo.get_all_backends()
   return backends
 
 
@@ -406,8 +416,8 @@
 
   @classmethod
   def _compile(cls, backend_info):
-    return backend_info.CompiledModule(cls._module_class, backend_info,
-                                       cls._exported_names, cls._artifacts_dir)
+    return backend_info.compile(cls._module_class, cls._exported_names,
+                                cls._artifacts_dir)
 
   @classmethod
   def setUpClass(cls):
@@ -421,25 +431,15 @@
     # Setup the directory for saving compilation artifacts and traces.
     cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
 
-    # Setup crash reproducer for the test.
-    crash_reproducer_path = os.path.join(cls._artifacts_dir, "reproducer.mlir")
-    compiler.Context.default_crash_reproducer_path = crash_reproducer_path
-
     # Create a CompiledModule for the reference backend and each target backend.
-    try:
-      ref_backend_info = tf_utils.BackendInfo.ALL[FLAGS.reference_backend]
-      cls._ref_module = cls._compile(ref_backend_info)
+    ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+                                            f"{FLAGS.reference_backend}_ref")
+    cls._ref_module = cls._compile(ref_backend_info)
 
-      tar_backend_infos = get_target_backends()
-      cls._tar_modules = [
-          cls._compile(backend_info) for backend_info in tar_backend_infos
-      ]
-    finally:
-      # TODO(meadowlark): Move this into tf_util.compile_tf_module to prevent
-      # overwritting `reproducer.mlir`.
-      # Disable crash reproducer (to avoid inadvertently overwriting this
-      # path if there are multiple TestCases in the same file).
-      compiler.Context.default_crash_reproducer_path = None
+    tar_backend_infos = get_target_backends()
+    cls._tar_modules = [
+        cls._compile(backend_info) for backend_info in tar_backend_infos
+    ]
 
   def setUp(self):
     # Ran before each unit test.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index f21521a..a0f24c5 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -112,7 +112,7 @@
       module.get_count()
 
     module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                       tf_utils.BackendInfo.ALL['tf'])
+                                       tf_utils.BackendInfo('tf'))
     trace = tf_test_utils.Trace(module, trace_function)
     trace_function(tf_test_utils.TracedModule(module, trace))
 
@@ -135,12 +135,12 @@
       module.decrement()
 
     tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo.ALL['tf'])
+                                          tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
     vmla_module = tf_utils.IreeCompiledModule(
-        StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
@@ -156,12 +156,12 @@
       module.increment_by(np.array([22.], dtype=np.float32))
 
     tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
-                                          tf_utils.BackendInfo.ALL['tf'])
+                                          tf_utils.BackendInfo('tf'))
     tf_trace = tf_test_utils.Trace(tf_module, tf_function)
     tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
 
     vmla_module = tf_utils.IreeCompiledModule(
-        StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+        StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
     vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
     vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
 
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index f67119e..d966346 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,7 +16,6 @@
 
 # pylint: disable=protected-access
 
-import collections
 import os
 import random
 import re
@@ -49,14 +48,14 @@
   return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
 
 
-def backends_to_str(target_backends):
-  """Creates a flattened and normalized string representing target_backends."""
-  normalized_backends = []
-  for backend in target_backends:
+def backends_to_str(backend_infos):
+  """Creates a normalized string representing the provided backends."""
+  normalized_names = []
+  for backend_info in backend_infos:
     # Remove unusual characters and ensure names don't end or start in "_".
-    backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
-    normalized_backends.append(backend.strip("_"))
-  return "__".join(normalized_backends)
+    name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
+    normalized_names.append(name.strip("_"))
+  return "__".join(normalized_names)
 
 
 def to_mlir_type(dtype):
@@ -90,7 +89,7 @@
 
 
 def compile_tf_module(tf_module,
-                      target_backends=(),
+                      backend_infos=(),
                       exported_names=(),
                       artifacts_dir=None):
   """Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
@@ -112,7 +111,7 @@
 
   Args:
     tf_module: A tf.Module.
-    target_backends: Iterable of string backend names to compile for.
+    backend_infos: Iterable of BackendInfo names to compile for.
     exported_names: Iterable of dotted function names to consider for
       compilation.
     artifacts_dir: An optional string pointing to where compilation artifacts
@@ -124,50 +123,62 @@
 
   def _compile_from_path(sm_path):
     """Helper function for compile_tf_module."""
-    # We break up the compilation here so we can save intermediary artifacts.
-    compiler_context = compiler.Context()
-
-    # Convert the tf_module into raw TF input MLIR.
-    compiler_module = compiler.tf_load_saved_model(
-        sm_path,
-        exported_names=exported_names,
-        compiler_context=compiler_context,
-        pass_pipeline=())
-
     if artifacts_dir is not None:
-      tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
-      logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
-      with open(tf_mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
+      # Set up a crash reproducer for debugging.
+      compiler.Context.default_crash_reproducer_path = os.path.join(
+          artifacts_dir, f"reproducer__{backends_string}.mlir")
+    try:
+      # We break up the compilation here so we can save intermediary artifacts.
+      compiler_context = compiler.Context()
 
-    # Now run the passes manually that tf_load_saved_model would usually do.
-    compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+      # Convert the tf_module into raw TF input MLIR.
+      compiler_module = compiler.tf_load_saved_model(
+          sm_path,
+          exported_names=exported_names,
+          compiler_context=compiler_context,
+          pass_pipeline=())
 
-    if artifacts_dir is not None:
-      iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
-      logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
-      with open(iree_mlir_path, "w") as f:
-        f.write(compiler_module.to_asm())
+      if artifacts_dir is not None:
+        tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+        logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+        with open(tf_mlir_path, "w") as f:
+          f.write(compiler_module.to_asm())
 
-    compiled_module = compiler_module.compile(target_backends=target_backends)
-    if artifacts_dir is not None:
-      compiled_name = f"compiled__{backends_to_str(target_backends)}.vmfb"
-      compiled_path = os.path.join(artifacts_dir, compiled_name)
-      logging.info("Saving compiled IREE module to: %s", compiled_path)
-      with open(compiled_path, "wb") as f:
-        f.write(compiled_module)
+      # Now run the passes manually that tf_load_saved_model would usually do.
+      compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
 
-    return compiled_module
+      if artifacts_dir is not None:
+        iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+        logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+        with open(iree_mlir_path, "w") as f:
+          f.write(compiler_module.to_asm())
+
+      target_backends = []
+      for backend_info in backend_infos:
+        target_backends.extend(backend_info.compiler_targets)
+      compiled_module = compiler_module.compile(target_backends=target_backends)
+
+      if artifacts_dir is not None:
+        compiled_name = f"compiled__{backends_string}.vmfb"
+        compiled_path = os.path.join(artifacts_dir, compiled_name)
+        logging.info("Saving compiled IREE module to: %s", compiled_path)
+        with open(compiled_path, "wb") as f:
+          f.write(compiled_module)
+
+      return compiled_module
+    except Exception:  # pylint: disable=broad-except
+      if artifacts_dir is not None:
+        # Disable the crash reproducer (to avoid inadvertently overwriting it).
+        compiler.Context.default_crash_reproducer_path = None
+      raise
 
   options = tf.saved_model.SaveOptions(save_debug_info=True)
+  backends_string = backends_to_str(backend_infos)
   if artifacts_dir is not None and FLAGS.keep_saved_model:
-    # Save the saved model alongside the other compilation artifacts.
-
     # Create a saved model for these target backends to avoid a race condition
     # when running a test suite.
     # TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
-    sm_path = os.path.join(artifacts_dir,
-                           f"saved_model__{backends_to_str(target_backends)}")
+    sm_path = os.path.join(artifacts_dir, f"saved_model__{backends_string}")
     tf.saved_model.save(tf_module, sm_path, options=options)
     return _compile_from_path(sm_path)
   else:
@@ -223,11 +234,11 @@
       set_random_seed()
       self._module_blob = compile_tf_module(
           tf_module=module_class(),
-          target_backends=backend_info.iree_compiler_targets,
+          backend_infos=[backend_info],
           exported_names=exported_names,
           artifacts_dir=artifacts_dir)
       self._module = rt.VmModule.from_flatbuffer(self._module_blob)
-      self._config = rt.Config(driver_name=backend_info.iree_driver)
+      self._config = rt.Config(driver_name=backend_info.driver)
     else:
       # Called from self.create_reinitialized()
       self._module_blob, self._module, self._config = _create_reinitialized_args
@@ -314,6 +325,8 @@
     self._f = f
 
   def _convert_to_numpy(self, tensor):
+    if not isinstance(tensor, tf.Tensor):
+      return tensor
     result = tensor.numpy()
     if np.isscalar(result):
       # convert_to_tensor isn't reversible via .numpy()
@@ -329,50 +342,63 @@
     if not isinstance(results, tuple):
       results = (results,)
     return tf.nest.map_structure(
-        lambda t: self._convert_to_numpy(t) if isinstance(t, tf.Tensor) else t,
-        *results,
-        check_types=False)
+        self._convert_to_numpy, *results, check_types=False)
 
 
-class BackendInfo(
-    collections.namedtuple(
-        "BackendInfo",
-        ["name", "CompiledModule", "iree_driver", "iree_compiler_targets"])):
-  """Info object describing a backend."""
+class BackendInfo:
 
-  # All BackendInfo entries by name.
-  ALL = {}
+  _name_to_info = {
+      "tf": {
+          "compiled_module_class": TfCompiledModule,
+          "driver": None,
+          "compiler_targets": None,
+      },
+      "iree_vmla": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vmla",
+          "compiler_targets": ["vmla"]
+      },
+      "iree_llvmjit": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "llvm",
+          "compiler_targets": ["llvm-ir"]
+      },
+      "iree_vulkan": {
+          "compiled_module_class": IreeCompiledModule,
+          "driver": "vulkan",
+          "compiler_targets": ["vulkan-*"]
+      },
+  }
+
+  def __init__(self, backend_name, artifact_name=None):
+    """Contains information for compiling the specified backend.
+
+    Args:
+      backend_name: a str specifying which backend to use. Should be one of
+        'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
+      artifact_name: an optional str specifying what name to use when saving
+        compiled artifacts.
+
+    Raises:
+      KeyError: if backend_name is not one of ['tf', 'iree_vmla',
+      'iree_llvmjit', 'iree_vulkan'].
+    """
+    if backend_name not in self._name_to_info:
+      raise KeyError(
+          "Expected backend_name to be one of "
+          f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+    info = self._name_to_info[backend_name]
+    self._compiled_module_class = info["compiled_module_class"]
+    self.driver = info["driver"]
+    self.compiler_targets = info["compiler_targets"]
+    self.name = backend_name if artifact_name is None else artifact_name
+
+  def compile(self, module, exported_names=(), artifacts_dir=None):
+    """Creates a `CompiledModule` for this backend."""
+    return self._compiled_module_class(module, self, exported_names,
+                                       artifacts_dir)
 
   @classmethod
-  def add(cls, **kwargs):
-    backend_info = cls(**kwargs)
-    cls.ALL[backend_info.name] = backend_info
-
-
-BackendInfo.add(
-    name="tf",
-    CompiledModule=TfCompiledModule,
-    iree_driver=None,
-    iree_compiler_targets=None)
-# tf_also is used for checking test consistency
-# to catch any initialization/randomization issues between model runs
-BackendInfo.add(
-    name="tf_also",
-    CompiledModule=TfCompiledModule,
-    iree_driver=None,
-    iree_compiler_targets=None)
-BackendInfo.add(
-    name="iree_vmla",
-    CompiledModule=IreeCompiledModule,
-    iree_driver="vmla",
-    iree_compiler_targets=["vmla"])
-BackendInfo.add(
-    name="iree_vulkan",
-    CompiledModule=IreeCompiledModule,
-    iree_driver="vulkan",
-    iree_compiler_targets=["vulkan-*"])
-BackendInfo.add(
-    name="iree_llvmjit",
-    CompiledModule=IreeCompiledModule,
-    iree_driver="llvm",
-    iree_compiler_targets=["llvm-ir"])
+  def get_all_backends(cls):
+    """Returns a list of all BackendInfo configurations."""
+    return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index cde0a08..aa1df8e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -50,25 +50,27 @@
   @parameterized.named_parameters([
       {
           'testcase_name': 'single_backend',
-          'target_backends': ['vmla'],
+          'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
       },
       {
-          'testcase_name': 'multiple_backends',
-          'target_backends': ['vmla', 'llvm-ir'],
+          'testcase_name':
+              'multiple_backends',
+          'backend_infos': [
+              tf_utils.BackendInfo('iree_vmla'),
+              tf_utils.BackendInfo('iree_llvmjit')
+          ],
       },
   ])
-  def test_artifact_saving(self, target_backends):
+  def test_artifact_saving(self, backend_infos):
     with tempfile.TemporaryDirectory() as artifacts_dir:
       tf_module = ConstantModule()
       iree_compiled_module = tf_utils.compile_tf_module(
-          tf_module,
-          target_backends=target_backends,
-          artifacts_dir=artifacts_dir)
+          tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
 
       artifacts_to_check = [
           'tf_input.mlir',
           'iree_input.mlir',
-          f'compiled__{tf_utils.backends_to_str(target_backends)}.vmfb',
+          f'compiled__{tf_utils.backends_to_str(backend_infos)}.vmfb',
       ]
       for artifact in artifacts_to_check:
         artifact_path = os.path.join(artifacts_dir, artifact)
@@ -86,8 +88,8 @@
       },
   ])
   def test_unaltered_state(self, backend_name):
-    backend_info = tf_utils.BackendInfo.ALL[backend_name]
-    module = backend_info.CompiledModule(StatefulCountingModule, backend_info)
+    backend_info = tf_utils.BackendInfo(backend_name)
+    module = backend_info.compile(StatefulCountingModule)
 
     # Test that incrementing works properly.
     self.assertEqual([0.], module.get_count())
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index a3346e4..224b162 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -117,7 +117,7 @@
 iree_e2e_test_suite(
     name = "e2e_tests",
     backends_to_srcs = {
-        "tf_also": TF_PASSING,
+        "tf": TF_PASSING,
         "iree_vmla": VMLA_PASSING,
         "iree_llvmjit": LLVM_PASSING,
         "iree_vulkan": VULKAN_PASSING,
@@ -154,7 +154,7 @@
     # TODO(#2082): `linspace_test.py` fails in the `bazel-tensorflow` image.
     name = "linspace_tests",
     backends_to_srcs = {
-        "tf_also": ["linspace_test.py"],
+        "tf": ["linspace_test.py"],
         "iree_vmla": ["linspace_test.py"],
     },
     reference_backend = "tf",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index a213bda..300cadf 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -40,7 +40,7 @@
 from pyiree.tf.support import tf_utils
 vmla_module = tf_utils.IreeCompiledModule(
     module_class=KerasTFModuleClass,
-    backend_info=tf_utils.BackendInfo.ALL['iree_vmla'],
+    backend_info=tf_utils.BackendInfo('iree_vmla'),
     exported_names=['predict'])
 vmla_module.predict(...)
 ```
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 7d1f1d3..f622d49 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -122,7 +122,7 @@
 iree_e2e_test_suite(
     name = "keras_tests",
     backends_to_srcs = {
-        "tf_also": TF_PASSING,
+        "tf": TF_PASSING,
         "iree_vmla": VMLA_PASSING,
         "iree_llvmjit": LLVM_PASSING,
         "iree_vulkan": VULKAN_PASSING,
@@ -154,7 +154,7 @@
 iree_vision_test_suite(
     name = "vision_internal_tests",
     backends = [
-        "tf_also",
+        "tf",
         "iree_vmla",
         "iree_llvmjit",
         "iree_vulkan",
@@ -171,7 +171,7 @@
 iree_vision_test_suite(
     name = "vision_external_tests",
     backends = [
-        "tf_also",
+        "tf",
         "iree_vmla",
         "iree_llvmjit",
         "iree_vulkan",
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index e9cf397..8f0a5f8 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -24,10 +24,10 @@
 import subprocess
 
 REFERENCE_BACKEND = 'tf'
-# Assumes that tests are expanded for the tf_also, iree_vmla, iree_llvmjit and
+# Assumes that tests are expanded for the tf, iree_vmla, iree_llvmjit and
 # iree_vulkan backends.
 BACKENDS_TO_TITLES = collections.OrderedDict([
-    ('tf_also', 'tensorflow'),
+    ('tf', 'tensorflow'),
     ('iree_vmla', 'vmla'),
     ('iree_llvmjit', 'llvm-ir'),
     ('iree_vulkan', 'vulkan-spirv'),