Update test suites in get_e2e_artifacts.py (#3583)
diff --git a/scripts/get_e2e_artifacts.py b/scripts/get_e2e_artifacts.py
index 5cb6f8e..7ed2691 100755
--- a/scripts/get_e2e_artifacts.py
+++ b/scripts/get_e2e_artifacts.py
@@ -15,8 +15,12 @@
# limitations under the License.
"""Runs all E2E TensorFlow tests and extracts their benchmarking artifacts.
-Example usage:
- python3 get_e2e_artifacts.py
+Example usages:
+ # Run all test suites and collect their artifacts:
+ python3 ./scripts/get_e2e_artifacts.py
+
+ # Run the e2e_tests test suite and collect its artifacts:
+ python3 ./scripts/get_e2e_artifacts.py --test_suites=e2e_tests
"""
import fileinput
@@ -39,8 +43,10 @@
'//integrations/tensorflow/e2e:mobile_bert_squad_tests',
'keras_tests':
'//integrations/tensorflow/e2e/keras:keras_tests',
- 'vision_external_tests':
- '//integrations/tensorflow/e2e/keras:vision_external_tests',
+ 'imagenet_external_tests':
+ '//integrations/tensorflow/e2e/keras:imagenet_external_tests',
+ 'slim_vision_tests':
+ '//integrations/tensorflow/e2e/slim_vision_models:slim_vision_tests',
}
SUITES_HELP = [f'`{name}`' for name in SUITE_NAME_TO_TARGET]
SUITES_HELP = f'{", ".join(SUITES_HELP[:-1])} and {SUITES_HELP[-1]}'
@@ -118,6 +124,11 @@
paths_to_tests: Dict[str, str]):
"""Unzips all of the benchmarking artifacts for a given test and backend."""
outputs = os.path.join(test_path, 'test.outputs', 'outputs.zip')
+ if FLAGS.dry_run and not os.path.exists(outputs):
+ # The artifacts may or may not be present on disk during a dry run. If they
+ # are then we want to collision check them, but if they aren't that's fine.
+ return
+
archive = zipfile.ZipFile(outputs)
# Filter out directory names.
filenames = [name for name in archive.namelist() if name[-1] != os.sep]
@@ -139,17 +150,21 @@
# Convert test suite shorthands to full test suite targets.
test_suites = [SUITE_NAME_TO_TARGET[suite] for suite in FLAGS.test_suites]
+ if FLAGS.run_test_suites:
+ # Use bazel test to execute all of the test suites in parallel.
+ command = [
+ 'bazel', 'test', *test_suites, '--color=yes',
+ '--test_arg=--get_saved_model'
+ ]
+ print(f'Running: `{" ".join(command)}`')
+ if not FLAGS.dry_run:
+ subprocess.check_call(command)
+ print()
+
written_paths = set()
paths_to_tests = dict()
for test_suite in test_suites:
- if FLAGS.run_test_suites and not FLAGS.dry_run:
- subprocess.check_call([
- 'bazel', 'test', test_suite, '--color=yes',
- '--test_arg=--get_saved_model'
- ])
- print()
-
# Extract all of the artifacts for this test suite.
test_paths, test_names = get_test_paths_and_names(test_suite)
for i, (test_path, test_name) in enumerate(zip(test_paths, test_names)):