blob: e432cf53af73c7fb2d716bec6c47be604ed49711 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import pathlib
import stat
import unittest
import tempfile
import os
from common import benchmark_config, common_arguments
class BenchmarkConfigTest(unittest.TestCase):
def setUp(self):
self._build_dir_manager = tempfile.TemporaryDirectory()
self._tmp_dir_manager = tempfile.TemporaryDirectory()
self.build_dir = pathlib.Path(self._build_dir_manager.name).resolve()
self.tmp_dir = pathlib.Path(self._tmp_dir_manager.name).resolve()
self.normal_tool_dir = self.build_dir / "normal_tool"
self.normal_tool_dir.mkdir()
self.traced_tool_dir = self.build_dir / "traced_tool"
self.traced_tool_dir.mkdir()
self.trace_capture_tool = tempfile.NamedTemporaryFile()
os.chmod(self.trace_capture_tool.name, stat.S_IEXEC)
def tearDown(self):
self.trace_capture_tool.close()
self._tmp_dir_manager.cleanup()
self._build_dir_manager.cleanup()
def test_build_from_args(self):
args = common_arguments.Parser().parse_args([
f"--tmp_dir={self.tmp_dir}",
f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
f"--traced_benchmark_tool_dir={self.traced_tool_dir}",
f"--trace_capture_tool={self.trace_capture_tool.name}",
f"--capture_tarball=capture.tar", f"--driver_filter_regex=a",
f"--model_name_regex=b", f"--mode_regex=c", f"--keep_going",
f"--benchmark_min_time=10",
str(self.build_dir)
])
config = benchmark_config.BenchmarkConfig.build_from_args(
args=args, git_commit_hash="abcd")
per_commit_tmp_dir = self.tmp_dir / "abcd"
expected_trace_capture_config = benchmark_config.TraceCaptureConfig(
traced_benchmark_tool_dir=self.traced_tool_dir,
trace_capture_tool=pathlib.Path(self.trace_capture_tool.name).resolve(),
capture_tarball=pathlib.Path("capture.tar").resolve(),
capture_tmp_dir=per_commit_tmp_dir / "captures")
expected_config = benchmark_config.BenchmarkConfig(
root_benchmark_dir=self.build_dir / "benchmark_suites",
benchmark_results_dir=per_commit_tmp_dir / "benchmark-results",
git_commit_hash="abcd",
normal_benchmark_tool_dir=self.normal_tool_dir,
trace_capture_config=expected_trace_capture_config,
driver_filter="a",
model_name_filter="b",
mode_filter="c",
keep_going=True,
benchmark_min_time=10)
self.assertEqual(config, expected_config)
def test_build_from_args_benchmark_only(self):
args = common_arguments.Parser().parse_args([
f"--tmp_dir={self.tmp_dir}",
f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
str(self.build_dir)
])
config = benchmark_config.BenchmarkConfig.build_from_args(
args=args, git_commit_hash="abcd")
self.assertIsNone(config.trace_capture_config)
def test_build_from_args_invalid_capture_args(self):
args = common_arguments.Parser().parse_args([
f"--tmp_dir={self.tmp_dir}",
f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
f"--traced_benchmark_tool_dir={self.traced_tool_dir}",
str(self.build_dir)
])
self.assertRaises(
ValueError, lambda: benchmark_config.BenchmarkConfig.build_from_args(
args=args, git_commit_hash="abcd"))
def test_build_from_args_with_e2e_test_artifacts_dir(self):
with tempfile.TemporaryDirectory() as e2e_test_artifacts_dir:
exec_bench_config = pathlib.Path(
e2e_test_artifacts_dir) / "exec_bench_config.json"
exec_bench_config.touch()
args = common_arguments.Parser().parse_args([
f"--tmp_dir={self.tmp_dir}",
f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
f"--e2e_test_artifacts_dir={e2e_test_artifacts_dir}",
f"--execution_benchmark_config={exec_bench_config}",
f"--target_device_name=device_a",
])
config = benchmark_config.BenchmarkConfig.build_from_args(
args=args, git_commit_hash="abcd")
self.assertEqual(config.root_benchmark_dir,
pathlib.Path(e2e_test_artifacts_dir))
def test_build_from_args_with_execution_benchmark_config_and_build_dir(self):
with tempfile.TemporaryDirectory() as e2e_test_artifacts_dir:
exec_bench_config = pathlib.Path(
e2e_test_artifacts_dir) / "exec_bench_config.json"
exec_bench_config.touch()
args = common_arguments.Parser().parse_args([
f"--tmp_dir={self.tmp_dir}",
f"--normal_benchmark_tool_dir={self.normal_tool_dir}",
f"--execution_benchmark_config={exec_bench_config}",
f"--target_device_name=device_a",
str(self.build_dir)
])
config = benchmark_config.BenchmarkConfig.build_from_args(
args=args, git_commit_hash="abcd")
self.assertEqual(
config.root_benchmark_dir,
self.build_dir / benchmark_config.E2E_TEST_ARTIFACTS_REL_PATH)
if __name__ == "__main__":
unittest.main()