Generalize the backend select in the test driver.
* Verified that running with and without works: IREE_TEST_BACKENDS="tf,iree_interpreter,iree_vulkan"
PiperOrigin-RevId: 281190129
diff --git a/bindings/python/pyiree/tf_interop/test_utils.py b/bindings/python/pyiree/tf_interop/test_utils.py
index 2f44ef8..d198731 100644
--- a/bindings/python/pyiree/tf_interop/test_utils.py
+++ b/bindings/python/pyiree/tf_interop/test_utils.py
@@ -29,11 +29,12 @@
import tensorflow.compat.v2 as tf
-def save_and_compile_tf_module(tf_module):
+def save_and_compile_tf_module(tf_module, target_backends=()):
with tempfile.TemporaryDirectory() as sm_path:
options = tf.saved_model.SaveOptions(save_debug_info=True)
tf.saved_model.save(tf_module, sm_path, options=options)
- return compiler.tf_compile_saved_model(sm_path)
+ return compiler.tf_compile_saved_model(
+ sm_path, target_backends=target_backends)
def dump_iree_module(m):
@@ -48,30 +49,17 @@
i += 1
-def get_default_test_backends():
- backends_env = os.environ.get("IREE_TEST_BACKENDS")
- if backends_env:
- return backends_env.split(",")
- else:
- return ("tf", "iree_interpreter")
-
-
class CompiledModule(object):
"""Base class for per-backend compiled module facade."""
- def __init__(self, ctor, backend_name):
+ def __init__(self, ctor, backend):
self._ctor = ctor
- self._backend_name = backend_name
+ self._backend = backend
@staticmethod
- def create(ctor, backend_name):
- if backend_name == "tf":
- return TfCompiledModule(ctor, backend_name)
- elif backend_name.startswith("iree_"):
- return IreeCompiledModule(ctor, backend_name)
- else:
- raise ValueError("Unrecognized @compile_modules backend: '%s'" %
- (backend_name,))
+ def create(ctor, backend):
+ compiled_module_class = backend.CompiledModule
+ return compiled_module_class(ctor, backend)
@property
def ctor(self):
@@ -133,29 +121,28 @@
class IreeCompiledModule(CompiledModule):
"""Iree compiled module."""
- def __init__(self, ctor, backend_name):
- super().__init__(ctor, backend_name)
- self._iree_module_blob = save_and_compile_tf_module(ctor())
+ def __init__(self, ctor, backend):
+ super().__init__(ctor, backend)
+ self._iree_module_blob = save_and_compile_tf_module(
+ ctor(), target_backends=backend.iree_compiler_targets)
self._iree_module = binding.vm.create_module_from_blob(
self._iree_module_blob)
def instantiate(self):
- return _IreeModuleInstance(self._backend_name, self._iree_module_blob,
+ return _IreeModuleInstance(self._backend, self._iree_module_blob,
self._iree_module)
class _IreeModuleInstance(object):
"""An instance of an IREE module."""
- def __init__(self, backend_name, iree_module_blob, iree_module):
- self._backend_name = backend_name
+ def __init__(self, backend, iree_module_blob, iree_module):
+ self._backend = backend
self._iree_module_blob = iree_module_blob
self._iree_module = iree_module
- # TODO(laurenzo): This driver name matching needs to be made more robust.
- driver_name = backend_name.split("_")[-1]
self._policy = binding.rt.Policy()
- instance = binding.rt.Instance(driver_name=driver_name)
+ instance = binding.rt.Instance(driver_name=backend.iree_driver)
self._context = binding.rt.Context(instance=instance, policy=self._policy)
self._context.register_module(self._iree_module)
@@ -364,6 +351,55 @@
return decorator
+class BackendInfo(
+ collections.namedtuple(
+ "BackendInfo",
+ ["name", "CompiledModule", "iree_driver", "iree_compiler_targets"])):
+ """Info object describing a backend."""
+
+ # All BackendInfo entries by name.
+ ALL = {}
+
+ @classmethod
+ def add(cls, **kwargs):
+ backend_info = cls(**kwargs)
+ cls.ALL[backend_info.name] = backend_info
+
+
+BackendInfo.add(
+ name="tf",
+ CompiledModule=TfCompiledModule,
+ iree_driver=None,
+ iree_compiler_targets=None)
+BackendInfo.add(
+ name="iree_interpreter",
+ CompiledModule=IreeCompiledModule,
+ iree_driver="interpreter",
+ iree_compiler_targets=["interpreter-*"])
+BackendInfo.add(
+ name="iree_vulkan",
+ CompiledModule=IreeCompiledModule,
+ iree_driver="vulkan",
+ iree_compiler_targets=["interpreter-*", "vulkan-*"])
+
+
+def get_default_test_backends():
+ """Gets the default sequence of BackendInfo instances to test against."""
+
+ backends_env = os.environ.get("IREE_TEST_BACKENDS")
+ if backends_env:
+ backends = []
+ for backend_name in backends_env.split(","):
+ try:
+ backends.append(BackendInfo.ALL[backend_name])
+ except IndexError:
+ raise ValueError("In 'IREE_TEST_BACKENDS' env var, unexpected name %s" %
+ (backend_name))
+ return backends
+ else:
+ return BackendInfo.ALL["tf"], BackendInfo.ALL["iree_interpreter"]
+
+
class SavedModelTestCase(tf.test.TestCase):
"""Tests against a SavedModel."""
@@ -390,7 +426,7 @@
if backends is None:
backends = get_default_test_backends()
cls.compiled_modules[name] = dict([
- (backend, CompiledModule.create(ctor, backend))
+ (backend.name, CompiledModule.create(ctor, backend))
for backend in backends
])
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index 7372177..bf0163d 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -47,7 +47,7 @@
bazel test simple_arithmetic_test --test_output=streamed
# Run tests with an altered list of backends.
-bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree.interpreter,iree.vulkan \
+bazel test ... --test_env=IREE_TEST_BACKENDS=tf,iree_interpreter,iree_vulkan \
--test_output=errors
```