Use global variable to avoid recompilation (#3372)
The `TestCase` class calls `__init__` before each unittest plus one
additional time. With our current implementation this meant that we
need to re-compile each model at least once for each backend,
which greatly increased the runtime of tests like MobileBERT.
diff --git a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
index bcc629e..00ea2c2 100644
--- a/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
+++ b/integrations/tensorflow/bindings/python/pyiree/tf/support/tf_test_utils.py
@@ -585,6 +585,15 @@
Modules = collections.namedtuple("Modules",
["ref_module", "tar_modules", "artifacts_dir"])
+# We have to use a global variable to store the compiled modules so that we can
+# avoid recompilation. This is because the TestCase class resets it's entire
+# state and calls __init__ before each unittest. It also calls __init__ one
+# additional time before that for good measure, which means without storing the
+# modules somewhere else we would have to compile each of them at least twice.
+# We can't store the modules on the class itself via setUpClass because of #2900
+global _global_modules
+_global_modules = None
+
def compile_tf_module(
module_class: Type[tf.Module], exported_names: Sequence[str] = ()
@@ -601,6 +610,9 @@
A 'Modules' namedtuple containing the reference module, target modules and
artifacts directory.
"""
+ global _global_modules
+ if _global_modules is not None:
+ return _global_modules
# Setup the directory for saving compilation artifacts and traces.
artifacts_dir = _setup_artifacts_dir(module_class.__name__)
@@ -617,7 +629,8 @@
tar_modules = [
compile_backend(backend_info) for backend_info in tar_backend_infos
]
- return Modules(ref_module, tar_modules, artifacts_dir)
+ _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
+ return _global_modules
def compile_tf_signature_def_saved_model(saved_model_dir: str,
@@ -641,6 +654,9 @@
A 'Modules' namedtuple containing the reference module, target modules and
artifacts directory.
"""
+ global _global_modules
+ if _global_modules is not None:
+ return _global_modules
# Setup the directory for saving compilation artifacts and traces.
artifacts_dir = _setup_artifacts_dir(module_name)
@@ -659,7 +675,8 @@
tar_modules = [
compile_backend(backend_info) for backend_info in tar_backend_infos
]
- return Modules(ref_module, tar_modules, artifacts_dir)
+ _global_modules = Modules(ref_module, tar_modules, artifacts_dir)
+ return _global_modules
class TracedModuleTestCase(tf.test.TestCase):