Generalize the backend select in the test driver.

* Verified that running with and without works: IREE_TEST_BACKENDS="tf,iree_interpreter,iree_vulkan"

PiperOrigin-RevId: 281190129
diff --git a/bindings/python/pyiree/tf_interop/test_utils.py b/bindings/python/pyiree/tf_interop/test_utils.py
index 2f44ef8..d198731 100644
--- a/bindings/python/pyiree/tf_interop/test_utils.py
+++ b/bindings/python/pyiree/tf_interop/test_utils.py
@@ -29,11 +29,12 @@
 import tensorflow.compat.v2 as tf
 
 
-def save_and_compile_tf_module(tf_module):
+def save_and_compile_tf_module(tf_module, target_backends=()):
   with tempfile.TemporaryDirectory() as sm_path:
     options = tf.saved_model.SaveOptions(save_debug_info=True)
     tf.saved_model.save(tf_module, sm_path, options=options)
-    return compiler.tf_compile_saved_model(sm_path)
+    return compiler.tf_compile_saved_model(
+        sm_path, target_backends=target_backends)
 
 
 def dump_iree_module(m):
@@ -48,30 +49,17 @@
     i += 1
 
 
-def get_default_test_backends():
-  backends_env = os.environ.get("IREE_TEST_BACKENDS")
-  if backends_env:
-    return backends_env.split(",")
-  else:
-    return ("tf", "iree_interpreter")
-
-
 class CompiledModule(object):
   """Base class for per-backend compiled module facade."""
 
-  def __init__(self, ctor, backend_name):
+  def __init__(self, ctor, backend):
     self._ctor = ctor
-    self._backend_name = backend_name
+    self._backend = backend
 
   @staticmethod
-  def create(ctor, backend_name):
-    if backend_name == "tf":
-      return TfCompiledModule(ctor, backend_name)
-    elif backend_name.startswith("iree_"):
-      return IreeCompiledModule(ctor, backend_name)
-    else:
-      raise ValueError("Unrecognized @compile_modules backend: '%s'" %
-                       (backend_name,))
+  def create(ctor, backend):
+    compiled_module_class = backend.CompiledModule
+    return compiled_module_class(ctor, backend)
 
   @property
   def ctor(self):
@@ -133,29 +121,28 @@
 class IreeCompiledModule(CompiledModule):
   """Iree compiled module."""
 
-  def __init__(self, ctor, backend_name):
-    super().__init__(ctor, backend_name)
-    self._iree_module_blob = save_and_compile_tf_module(ctor())
+  def __init__(self, ctor, backend):
+    super().__init__(ctor, backend)
+    self._iree_module_blob = save_and_compile_tf_module(
+        ctor(), target_backends=backend.iree_compiler_targets)
     self._iree_module = binding.vm.create_module_from_blob(
         self._iree_module_blob)
 
   def instantiate(self):
-    return _IreeModuleInstance(self._backend_name, self._iree_module_blob,
+    return _IreeModuleInstance(self._backend, self._iree_module_blob,
                                self._iree_module)
 
 
 class _IreeModuleInstance(object):
   """An instance of an IREE module."""
 
-  def __init__(self, backend_name, iree_module_blob, iree_module):
-    self._backend_name = backend_name
+  def __init__(self, backend, iree_module_blob, iree_module):
+    self._backend = backend
     self._iree_module_blob = iree_module_blob
     self._iree_module = iree_module
 
-    # TODO(laurenzo): This driver name matching needs to be made more robust.
-    driver_name = backend_name.split("_")[-1]
     self._policy = binding.rt.Policy()
-    instance = binding.rt.Instance(driver_name=driver_name)
+    instance = binding.rt.Instance(driver_name=backend.iree_driver)
     self._context = binding.rt.Context(instance=instance, policy=self._policy)
     self._context.register_module(self._iree_module)
 
@@ -364,6 +351,55 @@
   return decorator
 
 
+class BackendInfo(
+    collections.namedtuple(
+        "BackendInfo",
+        ["name", "CompiledModule", "iree_driver", "iree_compiler_targets"])):
+  """Info object describing a backend."""
+
+  # All BackendInfo entries by name.
+  ALL = {}
+
+  @classmethod
+  def add(cls, **kwargs):
+    backend_info = cls(**kwargs)
+    cls.ALL[backend_info.name] = backend_info
+
+
+BackendInfo.add(
+    name="tf",
+    CompiledModule=TfCompiledModule,
+    iree_driver=None,
+    iree_compiler_targets=None)
+BackendInfo.add(
+    name="iree_interpreter",
+    CompiledModule=IreeCompiledModule,
+    iree_driver="interpreter",
+    iree_compiler_targets=["interpreter-*"])
+BackendInfo.add(
+    name="iree_vulkan",
+    CompiledModule=IreeCompiledModule,
+    iree_driver="vulkan",
+    iree_compiler_targets=["interpreter-*", "vulkan-*"])
+
+
+def get_default_test_backends():
+  """Gets the default sequence of BackendInfo instances to test against."""
+
+  backends_env = os.environ.get("IREE_TEST_BACKENDS")
+  if backends_env:
+    backends = []
+    for backend_name in backends_env.split(","):
+      try:
+        backends.append(BackendInfo.ALL[backend_name])
+      except IndexError:
+        raise ValueError("In 'IREE_TEST_BACKENDS' env var, unexpected name %s" %
+                         (backend_name))
+    return backends
+  else:
+    return BackendInfo.ALL["tf"], BackendInfo.ALL["iree_interpreter"]
+
+
 class SavedModelTestCase(tf.test.TestCase):
   """Tests against a SavedModel."""
 
@@ -390,7 +426,7 @@
         if backends is None:
           backends = get_default_test_backends()
         cls.compiled_modules[name] = dict([
-            (backend, CompiledModule.create(ctor, backend))
+            (backend.name, CompiledModule.create(ctor, backend))
             for backend in backends
         ])
 
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 7372177..bf0163d 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -47,7 +47,7 @@
 bazel test simple_arithmetic_test --test_output=streamed
 
 # Run tests with an altered list of backends.
-bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan \
+bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree_interpreter,iree_vulkan \
     --test_output=errors
 ```