Refactor BackendInfo, remove tf_also, fully support duplicate backends (#2738)
`BackendInfo` is refactored from a wrapper around a set of `NamedTuple`s into its own class. This allows the name attribute to be overridden, which is necessary for saving artifacts for duplicate backends (e.g. two copies of `"iree_vmla"`). It also provides a cleaner API for compiling modules and getting all `BackendInfo` configurations.
Changes this allows:
- Compilation artifacts and traces from duplicate target backends will not overwrite each other. Instead artifacts for each backend are saved with an index (e.g. `trace__iree_vmla_0.txt`, `trace__iree_vmla_1.txt` and so on).
- The reference backend will have `_ref` appended to its name (e.g. `trace__tf_ref.txt` or `trace__iree_vmla_ref.txt`).
- `tf_also` is removed.
Example artifacts:
```
SimpleArithmeticModule
├── compiled__iree_llvmjit.vmfb
├── compiled__iree_vmla.vmfb
├── compiled__iree_vulkan.vmfb
├── iree_input.mlir
├── tf_input.mlir
└── traces
├── simple_matmul__iree_llvmjit.txt
├── simple_matmul__iree_vmla.txt
├── simple_matmul__iree_vulkan.txt
├── simple_matmul__tf_ref.txt
├── simple_matmul__tf.txt
├── simple_mul__iree_llvmjit.txt
├── simple_mul__iree_vmla.txt
├── simple_mul__iree_vulkan.txt
├── simple_mul__tf_ref.txt
└── simple_mul__tf.txt
```
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index 2a4de30..b3235b3 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -67,17 +67,23 @@
return artifacts_dir
-def _parse_target_backends(target_backends):
- """Decodes a comma-delimited string of backends into BackendInfo objects."""
- backends = []
- for backend_name in target_backends.split(","):
- if backend_name not in tf_utils.BackendInfo.ALL.keys():
- raise ValueError(
- "Invalid backend specification string '{}', unexpected name '{}';"
- " valid names are '{}'".format(target_backends, backend_name,
- tf_utils.BackendInfo.ALL.keys()))
- backends.append(tf_utils.BackendInfo.ALL[backend_name])
- return backends
+def _parse_target_backends():
+ """Decodes --target_backends and creates unique names for their artifacts."""
+ backend_names = FLAGS.target_backends.split(",")
+ backend_to_index = {k: 0 for k in backend_names if backend_names.count(k) > 1}
+ artifact_names = []
+
+ # If there are multiple copies of the same backend_name, index them. e.g.
+ # backend_names = ["tf", "iree_vmla", "tf"]
+ # --> artifact_names = ["tf_0", "iree_vmla", "tf_1"]
+ for backend_name in backend_names:
+ if backend_name in backend_to_index:
+ artifact_names.append(f"{backend_name}_{backend_to_index[backend_name]}")
+ backend_to_index[backend_name] += 1
+ else:
+ artifact_names.append(backend_name)
+
+ return backend_names, artifact_names
def get_target_backends():
@@ -91,10 +97,14 @@
"""
if FLAGS.target_backends is not None:
logging.info("Using backends from command line: %s", FLAGS.target_backends)
- backends = _parse_target_backends(FLAGS.target_backends)
+ backend_names, names = _parse_target_backends()
+ backends = [
+ tf_utils.BackendInfo(backend, name)
+ for backend, name in zip(backend_names, names)
+ ]
else:
# If no backends are specified, use them all.
- backends = list(tf_utils.BackendInfo.ALL.values())
+ backends = tf_utils.BackendInfo.get_all_backends()
return backends
@@ -406,8 +416,8 @@
@classmethod
def _compile(cls, backend_info):
- return backend_info.CompiledModule(cls._module_class, backend_info,
- cls._exported_names, cls._artifacts_dir)
+ return backend_info.compile(cls._module_class, cls._exported_names,
+ cls._artifacts_dir)
@classmethod
def setUpClass(cls):
@@ -421,25 +431,15 @@
# Setup the directory for saving compilation artifacts and traces.
cls._artifacts_dir = _setup_artifacts_dir(cls._module_class.__name__)
- # Setup crash reproducer for the test.
- crash_reproducer_path = os.path.join(cls._artifacts_dir, "reproducer.mlir")
- compiler.Context.default_crash_reproducer_path = crash_reproducer_path
-
# Create a CompiledModule for the reference backend and each target backend.
- try:
- ref_backend_info = tf_utils.BackendInfo.ALL[FLAGS.reference_backend]
- cls._ref_module = cls._compile(ref_backend_info)
+ ref_backend_info = tf_utils.BackendInfo(FLAGS.reference_backend,
+ f"{FLAGS.reference_backend}_ref")
+ cls._ref_module = cls._compile(ref_backend_info)
- tar_backend_infos = get_target_backends()
- cls._tar_modules = [
- cls._compile(backend_info) for backend_info in tar_backend_infos
- ]
- finally:
- # TODO(meadowlark): Move this into tf_util.compile_tf_module to prevent
- # overwritting `reproducer.mlir`.
- # Disable crash reproducer (to avoid inadvertently overwriting this
- # path if there are multiple TestCases in the same file).
- compiler.Context.default_crash_reproducer_path = None
+ tar_backend_infos = get_target_backends()
+ cls._tar_modules = [
+ cls._compile(backend_info) for backend_info in tar_backend_infos
+ ]
def setUp(self):
# Ran before each unit test.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
index f21521a..a0f24c5 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils_test.py
@@ -112,7 +112,7 @@
module.get_count()
module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo.ALL['tf'])
+ tf_utils.BackendInfo('tf'))
trace = tf_test_utils.Trace(module, trace_function)
trace_function(tf_test_utils.TracedModule(module, trace))
@@ -135,12 +135,12 @@
module.decrement()
tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo.ALL['tf'])
+ tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
vmla_module = tf_utils.IreeCompiledModule(
- StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
@@ -156,12 +156,12 @@
module.increment_by(np.array([22.], dtype=np.float32))
tf_module = tf_utils.TfCompiledModule(StatefulCountingModule,
- tf_utils.BackendInfo.ALL['tf'])
+ tf_utils.BackendInfo('tf'))
tf_trace = tf_test_utils.Trace(tf_module, tf_function)
tf_function(tf_test_utils.TracedModule(tf_module, tf_trace))
vmla_module = tf_utils.IreeCompiledModule(
- StatefulCountingModule, tf_utils.BackendInfo.ALL['iree_vmla'])
+ StatefulCountingModule, tf_utils.BackendInfo('iree_vmla'))
vmla_trace = tf_test_utils.Trace(vmla_module, vmla_function)
vmla_function(tf_test_utils.TracedModule(vmla_module, vmla_trace))
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
index f67119e..d966346 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils.py
@@ -16,7 +16,6 @@
# pylint: disable=protected-access
-import collections
import os
import random
import re
@@ -49,14 +48,14 @@
return np.arange(np.prod(shape), dtype=dtype).reshape(shape)
-def backends_to_str(target_backends):
- """Creates a flattened and normalized string representing target_backends."""
- normalized_backends = []
- for backend in target_backends:
+def backends_to_str(backend_infos):
+ """Creates a normalized string representing the provided backends."""
+ normalized_names = []
+ for backend_info in backend_infos:
# Remove unusual characters and ensure names don't end or start in "_".
- backend = re.sub("[^0-9a-zA-Z_]+", "_", backend)
- normalized_backends.append(backend.strip("_"))
- return "__".join(normalized_backends)
+ name = re.sub("[^0-9a-zA-Z_]+", "_", backend_info.name)
+ normalized_names.append(name.strip("_"))
+ return "__".join(normalized_names)
def to_mlir_type(dtype):
@@ -90,7 +89,7 @@
def compile_tf_module(tf_module,
- target_backends=(),
+ backend_infos=(),
exported_names=(),
artifacts_dir=None):
"""Compiles a TensorFlow tf.Module and optionally saves compilation artifacts.
@@ -112,7 +111,7 @@
Args:
tf_module: A tf.Module.
- target_backends: Iterable of string backend names to compile for.
+ backend_infos: Iterable of BackendInfo names to compile for.
exported_names: Iterable of dotted function names to consider for
compilation.
artifacts_dir: An optional string pointing to where compilation artifacts
@@ -124,50 +123,62 @@
def _compile_from_path(sm_path):
"""Helper function for compile_tf_module."""
- # We break up the compilation here so we can save intermediary artifacts.
- compiler_context = compiler.Context()
-
- # Convert the tf_module into raw TF input MLIR.
- compiler_module = compiler.tf_load_saved_model(
- sm_path,
- exported_names=exported_names,
- compiler_context=compiler_context,
- pass_pipeline=())
-
if artifacts_dir is not None:
- tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
- logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
- with open(tf_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
+ # Set up a crash reproducer for debugging.
+ compiler.Context.default_crash_reproducer_path = os.path.join(
+ artifacts_dir, f"reproducer__{backends_string}.mlir")
+ try:
+ # We break up the compilation here so we can save intermediary artifacts.
+ compiler_context = compiler.Context()
- # Now run the passes manually that tf_load_saved_model would usually do.
- compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
+ # Convert the tf_module into raw TF input MLIR.
+ compiler_module = compiler.tf_load_saved_model(
+ sm_path,
+ exported_names=exported_names,
+ compiler_context=compiler_context,
+ pass_pipeline=())
- if artifacts_dir is not None:
- iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
- logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
- with open(iree_mlir_path, "w") as f:
- f.write(compiler_module.to_asm())
+ if artifacts_dir is not None:
+ tf_mlir_path = os.path.join(artifacts_dir, "tf_input.mlir")
+ logging.info("Saving raw TF input MLIR to: %s", tf_mlir_path)
+ with open(tf_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
- compiled_module = compiler_module.compile(target_backends=target_backends)
- if artifacts_dir is not None:
- compiled_name = f"compiled__{backends_to_str(target_backends)}.vmfb"
- compiled_path = os.path.join(artifacts_dir, compiled_name)
- logging.info("Saving compiled IREE module to: %s", compiled_path)
- with open(compiled_path, "wb") as f:
- f.write(compiled_module)
+ # Now run the passes manually that tf_load_saved_model would usually do.
+ compiler_module.run_pass_pipeline(compiler.TF_IMPORT_PASS_PIPELINE)
- return compiled_module
+ if artifacts_dir is not None:
+ iree_mlir_path = os.path.join(artifacts_dir, "iree_input.mlir")
+ logging.info("Saving IREE input MLIR to: %s", iree_mlir_path)
+ with open(iree_mlir_path, "w") as f:
+ f.write(compiler_module.to_asm())
+
+ target_backends = []
+ for backend_info in backend_infos:
+ target_backends.extend(backend_info.compiler_targets)
+ compiled_module = compiler_module.compile(target_backends=target_backends)
+
+ if artifacts_dir is not None:
+ compiled_name = f"compiled__{backends_string}.vmfb"
+ compiled_path = os.path.join(artifacts_dir, compiled_name)
+ logging.info("Saving compiled IREE module to: %s", compiled_path)
+ with open(compiled_path, "wb") as f:
+ f.write(compiled_module)
+
+ return compiled_module
+ except Exception: # pylint: disable=broad-except
+ if artifacts_dir is not None:
+ # Disable the crash reproducer (to avoid inadvertently overwriting it).
+ compiler.Context.default_crash_reproducer_path = None
+ raise
options = tf.saved_model.SaveOptions(save_debug_info=True)
+ backends_string = backends_to_str(backend_infos)
if artifacts_dir is not None and FLAGS.keep_saved_model:
- # Save the saved model alongside the other compilation artifacts.
-
# Create a saved model for these target backends to avoid a race condition
# when running a test suite.
# TODO(meadowlark): Remove this once we have a TfLiteCompiledModule.
- sm_path = os.path.join(artifacts_dir,
- f"saved_model__{backends_to_str(target_backends)}")
+ sm_path = os.path.join(artifacts_dir, f"saved_model__{backends_string}")
tf.saved_model.save(tf_module, sm_path, options=options)
return _compile_from_path(sm_path)
else:
@@ -223,11 +234,11 @@
set_random_seed()
self._module_blob = compile_tf_module(
tf_module=module_class(),
- target_backends=backend_info.iree_compiler_targets,
+ backend_infos=[backend_info],
exported_names=exported_names,
artifacts_dir=artifacts_dir)
self._module = rt.VmModule.from_flatbuffer(self._module_blob)
- self._config = rt.Config(driver_name=backend_info.iree_driver)
+ self._config = rt.Config(driver_name=backend_info.driver)
else:
# Called from self.create_reinitialized()
self._module_blob, self._module, self._config = _create_reinitialized_args
@@ -314,6 +325,8 @@
self._f = f
def _convert_to_numpy(self, tensor):
+ if not isinstance(tensor, tf.Tensor):
+ return tensor
result = tensor.numpy()
if np.isscalar(result):
# convert_to_tensor isn't reversible via .numpy()
@@ -329,50 +342,63 @@
if not isinstance(results, tuple):
results = (results,)
return tf.nest.map_structure(
- lambda t: self._convert_to_numpy(t) if isinstance(t, tf.Tensor) else t,
- *results,
- check_types=False)
+ self._convert_to_numpy, *results, check_types=False)
-class BackendInfo(
- collections.namedtuple(
- "BackendInfo",
- ["name", "CompiledModule", "iree_driver", "iree_compiler_targets"])):
- """Info object describing a backend."""
+class BackendInfo:
- # All BackendInfo entries by name.
- ALL = {}
+ _name_to_info = {
+ "tf": {
+ "compiled_module_class": TfCompiledModule,
+ "driver": None,
+ "compiler_targets": None,
+ },
+ "iree_vmla": {
+ "compiled_module_class": IreeCompiledModule,
+ "driver": "vmla",
+ "compiler_targets": ["vmla"]
+ },
+ "iree_llvmjit": {
+ "compiled_module_class": IreeCompiledModule,
+ "driver": "llvm",
+ "compiler_targets": ["llvm-ir"]
+ },
+ "iree_vulkan": {
+ "compiled_module_class": IreeCompiledModule,
+ "driver": "vulkan",
+ "compiler_targets": ["vulkan-*"]
+ },
+ }
+
+ def __init__(self, backend_name, artifact_name=None):
+ """Contains information for compiling the specified backend.
+
+ Args:
+ backend_name: a str specifying which backend to use. Should be one of
+ 'tf', 'iree_vmla', 'iree_llvmjit', 'iree_vulkan'.
+ artifact_name: an optional str specifying what name to use when saving
+ compiled artifacts.
+
+ Raises:
+ KeyError: if backend_name is not one of ['tf', 'iree_vmla',
+ 'iree_llvmjit', 'iree_vulkan'].
+ """
+ if backend_name not in self._name_to_info:
+ raise KeyError(
+ "Expected backend_name to be one of "
+ f"{list(self._name_to_info.keys())} but got '{backend_name}'.")
+ info = self._name_to_info[backend_name]
+ self._compiled_module_class = info["compiled_module_class"]
+ self.driver = info["driver"]
+ self.compiler_targets = info["compiler_targets"]
+ self.name = backend_name if artifact_name is None else artifact_name
+
+ def compile(self, module, exported_names=(), artifacts_dir=None):
+ """Creates a `CompiledModule` for this backend."""
+ return self._compiled_module_class(module, self, exported_names,
+ artifacts_dir)
@classmethod
- def add(cls, **kwargs):
- backend_info = cls(**kwargs)
- cls.ALL[backend_info.name] = backend_info
-
-
-BackendInfo.add(
- name="tf",
- CompiledModule=TfCompiledModule,
- iree_driver=None,
- iree_compiler_targets=None)
-# tf_also is used for checking test consistency
-# to catch any initialization/randomization issues between model runs
-BackendInfo.add(
- name="tf_also",
- CompiledModule=TfCompiledModule,
- iree_driver=None,
- iree_compiler_targets=None)
-BackendInfo.add(
- name="iree_vmla",
- CompiledModule=IreeCompiledModule,
- iree_driver="vmla",
- iree_compiler_targets=["vmla"])
-BackendInfo.add(
- name="iree_vulkan",
- CompiledModule=IreeCompiledModule,
- iree_driver="vulkan",
- iree_compiler_targets=["vulkan-*"])
-BackendInfo.add(
- name="iree_llvmjit",
- CompiledModule=IreeCompiledModule,
- iree_driver="llvm",
- iree_compiler_targets=["llvm-ir"])
+ def get_all_backends(cls):
+ """Returns a list of all BackendInfo configurations."""
+ return [BackendInfo(backend_name) for backend_name in cls._name_to_info]
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
index cde0a08..aa1df8e 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_utils_test.py
@@ -50,25 +50,27 @@
@parameterized.named_parameters([
{
'testcase_name': 'single_backend',
- 'target_backends': ['vmla'],
+ 'backend_infos': [tf_utils.BackendInfo('iree_vmla')],
},
{
- 'testcase_name': 'multiple_backends',
- 'target_backends': ['vmla', 'llvm-ir'],
+ 'testcase_name':
+ 'multiple_backends',
+ 'backend_infos': [
+ tf_utils.BackendInfo('iree_vmla'),
+ tf_utils.BackendInfo('iree_llvmjit')
+ ],
},
])
- def test_artifact_saving(self, target_backends):
+ def test_artifact_saving(self, backend_infos):
with tempfile.TemporaryDirectory() as artifacts_dir:
tf_module = ConstantModule()
iree_compiled_module = tf_utils.compile_tf_module(
- tf_module,
- target_backends=target_backends,
- artifacts_dir=artifacts_dir)
+ tf_module, backend_infos=backend_infos, artifacts_dir=artifacts_dir)
artifacts_to_check = [
'tf_input.mlir',
'iree_input.mlir',
- f'compiled__{tf_utils.backends_to_str(target_backends)}.vmfb',
+ f'compiled__{tf_utils.backends_to_str(backend_infos)}.vmfb',
]
for artifact in artifacts_to_check:
artifact_path = os.path.join(artifacts_dir, artifact)
@@ -86,8 +88,8 @@
},
])
def test_unaltered_state(self, backend_name):
- backend_info = tf_utils.BackendInfo.ALL[backend_name]
- module = backend_info.CompiledModule(StatefulCountingModule, backend_info)
+ backend_info = tf_utils.BackendInfo(backend_name)
+ module = backend_info.compile(StatefulCountingModule)
# Test that incrementing works properly.
self.assertEqual([0.], module.get_count())
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index a3346e4..224b162 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -117,7 +117,7 @@
iree_e2e_test_suite(
name = "e2e_tests",
backends_to_srcs = {
- "tf_also": TF_PASSING,
+ "tf": TF_PASSING,
"iree_vmla": VMLA_PASSING,
"iree_llvmjit": LLVM_PASSING,
"iree_vulkan": VULKAN_PASSING,
@@ -154,7 +154,7 @@
# TODO(#2082): `linspace_test.py` fails in the `bazel-tensorflow` image.
name = "linspace_tests",
backends_to_srcs = {
- "tf_also": ["linspace_test.py"],
+ "tf": ["linspace_test.py"],
"iree_vmla": ["linspace_test.py"],
},
reference_backend = "tf",
diff --git a/integrations/tensorflow/e2e/README.md b/integrations/tensorflow/e2e/README.md
index a213bda..300cadf 100644
--- a/integrations/tensorflow/e2e/README.md
+++ b/integrations/tensorflow/e2e/README.md
@@ -40,7 +40,7 @@
from pyiree.tf.support import tf_utils
vmla_module = tf_utils.IreeCompiledModule(
module_class=KerasTFModuleClass,
- backend_info=tf_utils.BackendInfo.ALL['iree_vmla'],
+ backend_info=tf_utils.BackendInfo('iree_vmla'),
exported_names=['predict'])
vmla_module.predict(...)
```
diff --git a/integrations/tensorflow/e2e/keras/BUILD b/integrations/tensorflow/e2e/keras/BUILD
index 7d1f1d3..f622d49 100644
--- a/integrations/tensorflow/e2e/keras/BUILD
+++ b/integrations/tensorflow/e2e/keras/BUILD
@@ -122,7 +122,7 @@
iree_e2e_test_suite(
name = "keras_tests",
backends_to_srcs = {
- "tf_also": TF_PASSING,
+ "tf": TF_PASSING,
"iree_vmla": VMLA_PASSING,
"iree_llvmjit": LLVM_PASSING,
"iree_vulkan": VULKAN_PASSING,
@@ -154,7 +154,7 @@
iree_vision_test_suite(
name = "vision_internal_tests",
backends = [
- "tf_also",
+ "tf",
"iree_vmla",
"iree_llvmjit",
"iree_vulkan",
@@ -171,7 +171,7 @@
iree_vision_test_suite(
name = "vision_external_tests",
backends = [
- "tf_also",
+ "tf",
"iree_vmla",
"iree_llvmjit",
"iree_vulkan",
diff --git a/scripts/update_e2e_coverage.py b/scripts/update_e2e_coverage.py
index e9cf397..8f0a5f8 100755
--- a/scripts/update_e2e_coverage.py
+++ b/scripts/update_e2e_coverage.py
@@ -24,10 +24,10 @@
import subprocess
REFERENCE_BACKEND = 'tf'
-# Assumes that tests are expanded for the tf_also, iree_vmla, iree_llvmjit and
+# Assumes that tests are expanded for the tf, iree_vmla, iree_llvmjit and
# iree_vulkan backends.
BACKENDS_TO_TITLES = collections.OrderedDict([
- ('tf_also', 'tensorflow'),
+ ('tf', 'tensorflow'),
('iree_vmla', 'vmla'),
('iree_llvmjit', 'llvm-ir'),
('iree_vulkan', 'vulkan-spirv'),