| #!/usr/bin/env python3 |
| # Copyright 2022 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| import pathlib |
| import unittest |
| import tempfile |
| |
| from common import benchmark_config, benchmark_definition, common_arguments |
| |
| |
class BenchmarkConfigTest(unittest.TestCase):
    """Tests for ``benchmark_config.BenchmarkConfig.build_from_args``.

    Each test parses a command line with ``common_arguments.Parser`` and
    checks the ``BenchmarkConfig`` built from the parsed arguments. The paths
    passed on the command line point into per-test temporary directories
    created in ``setUp``.
    """

    def setUp(self):
        """Create the temporary directories and fixture files used as flags.

        Cleanups are registered with ``addCleanup`` immediately after each
        temporary directory is created rather than in ``tearDown``: unittest
        does not call ``tearDown`` when ``setUp`` raises, but it still runs
        registered cleanups, so a failure in any ``mkdir()``/``touch()`` below
        cannot leak the directories. Cleanups run in LIFO order, so the build
        directory is removed first and the tmp directory second.
        """
        self._tmp_dir_manager = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir_manager.cleanup)
        self.tmp_dir = pathlib.Path(self._tmp_dir_manager.name).resolve()

        self._build_dir_manager = tempfile.TemporaryDirectory()
        self.addCleanup(self._build_dir_manager.cleanup)
        self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve()

        self.e2e_test_artifacts_dir = self.build_dir / "e2e_test_artifacts"
        self.e2e_test_artifacts_dir.mkdir()
        self.normal_tool_dir = self.build_dir / "normal_tool"
        self.normal_tool_dir.mkdir()
        self.traced_tool_dir = self.build_dir / "traced_tool"
        self.traced_tool_dir.mkdir()
        self.trace_capture_tool = self.build_dir / "tracy_capture"
        # Create capture tool with executable file mode.
        self.trace_capture_tool.touch(mode=0o755)
        self.execution_config = self.build_dir / "execution_config.json"
        self.execution_config.touch()

    def test_build_from_args(self):
        """A fully specified command line yields the expected config."""
        args = common_arguments.Parser().parse_args(
            [
                f"--tmp_dir={self.tmp_dir}",
                f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
                f"--traced_benchmark_tool_dir={self.traced_tool_dir}",
                f"--trace_capture_tool={self.trace_capture_tool}",
                "--capture_tarball=capture.tar",
                "--driver_filter_regex=a",
                "--model_name_regex=b",
                "--mode_regex=c",
                "--keep_going",
                "--benchmark_min_time=10",
                "--compatible_only",
                f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}",
                f"--execution_benchmark_config={self.execution_config}",
                "--target_device=test",
                "--verify",
            ]
        )

        config = benchmark_config.BenchmarkConfig.build_from_args(
            args=args, git_commit_hash="abcd"
        )

        # The expected per-commit paths are nested under <tmp_dir>/<commit hash>.
        per_commit_tmp_dir = self.tmp_dir / "abcd"
        expected_trace_capture_config = benchmark_config.TraceCaptureConfig(
            traced_benchmark_tool_dir=self.traced_tool_dir,
            trace_capture_tool=pathlib.Path(self.trace_capture_tool).resolve(),
            # The relative tarball name is resolved against the current
            # working directory here.
            capture_tarball=pathlib.Path("capture.tar").resolve(),
            capture_tmp_dir=per_commit_tmp_dir / "captures",
        )
        expected_config = benchmark_config.BenchmarkConfig(
            tmp_dir=per_commit_tmp_dir,
            root_benchmark_dir=benchmark_definition.ResourceLocation.build_local_path(
                self.e2e_test_artifacts_dir
            ),
            benchmark_results_dir=per_commit_tmp_dir / "benchmark-results",
            git_commit_hash="abcd",
            normal_benchmark_tool_dir=self.normal_tool_dir,
            trace_capture_config=expected_trace_capture_config,
            driver_filter="a",
            model_name_filter="b",
            mode_filter="c",
            keep_going=True,
            benchmark_min_time=10,
            use_compatible_filter=True,
            verify=True,
        )
        self.assertEqual(config, expected_config)

    def test_build_from_args_benchmark_only(self):
        """Without any trace-capture flags, ``trace_capture_config`` is None."""
        args = common_arguments.Parser().parse_args(
            [
                f"--tmp_dir={self.tmp_dir}",
                f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
                f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}",
                f"--execution_benchmark_config={self.execution_config}",
                "--target_device=test",
            ]
        )

        config = benchmark_config.BenchmarkConfig.build_from_args(
            args=args, git_commit_hash="abcd"
        )

        self.assertIsNone(config.trace_capture_config)

    def test_build_from_args_with_test_artifacts_dir_url(self):
        """A URL passed as the artifacts dir is reported back by ``get_url()``."""
        args = common_arguments.Parser().parse_args(
            [
                f"--tmp_dir={self.tmp_dir}",
                f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
                "--e2e_test_artifacts_dir=https://example.com/testdata",
                f"--execution_benchmark_config={self.execution_config}",
                "--target_device=test",
            ]
        )

        config = benchmark_config.BenchmarkConfig.build_from_args(
            args=args, git_commit_hash="abcd"
        )

        self.assertEqual(
            config.root_benchmark_dir.get_url(), "https://example.com/testdata"
        )

    def test_build_from_args_invalid_capture_args(self):
        """A traced tool dir without the other capture flags raises ValueError."""
        # Only --traced_benchmark_tool_dir is given; --trace_capture_tool and
        # --capture_tarball are deliberately omitted.
        args = common_arguments.Parser().parse_args(
            [
                f"--tmp_dir={self.tmp_dir}",
                f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
                f"--traced_benchmark_tool_dir={self.traced_tool_dir}",
                f"--e2e_test_artifacts_dir={self.e2e_test_artifacts_dir}",
                f"--execution_benchmark_config={self.execution_config}",
                "--target_device=test",
            ]
        )

        with self.assertRaises(ValueError):
            benchmark_config.BenchmarkConfig.build_from_args(
                args=args, git_commit_hash="abcd"
            )
| |
| |
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()