| # Copyright 2022 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| # All benchmarks' relative path against root build directory. |
| |
| from argparse import Namespace |
| from dataclasses import dataclass |
| from typing import Optional |
| import pathlib |
| import re |
| |
| from common import benchmark_definition |
| |
# Directory name, relative to the per-commit temporary directory, under which
# benchmark results files are stored (see BenchmarkConfig.build_from_args).
BENCHMARK_RESULTS_REL_PATH = "benchmark-results"
| |
| |
@dataclass
class BenchmarkConfig:
    """Settings that control a benchmark run.

    Attributes:
      tmp_dir: per-commit temporary directory.
      root_benchmark_dir: location (local path or URL) of the root directory
        that contains the built benchmark suites.
      benchmark_results_dir: directory in which benchmark results files are
        stored.
      git_commit_hash: the git commit hash.
      benchmark_tool_dir: path to the directory holding the benchmark tools.
      driver_filter: regex; only benchmarks whose driver matches are run
        (None means no filtering).
      model_name_filter: regex; only benchmarks whose model name matches are
        run (None means no filtering).
      mode_filter: regex; only benchmarks whose benchmarking mode matches are
        run (None means no filtering).
      use_compatible_filter: populated from `args.compatible_only`.
        NOTE(review): presumably restricts the run to benchmarks compatible
        with the target device -- confirm against the code that consumes it.
      keep_going: whether to carry on when an individual run fails. Exceptions
        are logged and returned.
      benchmark_min_time: minimum number of seconds to run each benchmark for,
        if specified. Otherwise the benchmark is repeated a fixed number of
        times.
      continue_from_previous: skip benchmarks whose results already exist in
        benchmark_results_dir.
      verify: verify the output when the model's expected output is available.
    """

    # Required settings.
    tmp_dir: pathlib.Path
    root_benchmark_dir: benchmark_definition.ResourceLocation
    benchmark_results_dir: pathlib.Path
    git_commit_hash: str

    # Tool location.
    benchmark_tool_dir: Optional[pathlib.Path] = None

    # Benchmark selection.
    driver_filter: Optional[str] = None
    model_name_filter: Optional[str] = None
    mode_filter: Optional[str] = None
    use_compatible_filter: bool = False

    # Run behaviour.
    keep_going: bool = False
    benchmark_min_time: float = 0
    continue_from_previous: bool = False
    verify: bool = False

    @staticmethod
    def build_from_args(args: Namespace, git_commit_hash: str) -> "BenchmarkConfig":
        """Create a BenchmarkConfig from parsed command-line arguments.

        Args:
          args: the parsed command arguments. The attributes read are
            benchmark_tool_dir, tmp_dir, e2e_test_artifacts_dir,
            driver_filter_regex, model_name_regex, mode_regex,
            compatible_only, keep_going, benchmark_min_time,
            continue_from_previous and verify.
          git_commit_hash: the git commit hash of IREE; also used as the name
            of the per-commit sub-directory of `args.tmp_dir`.

        Returns:
          A BenchmarkConfig whose tmp_dir is `<args.tmp_dir>/<git_commit_hash>`
          (resolved) and whose benchmark_results_dir lives underneath it.

        Raises:
          ValueError: if `args.benchmark_tool_dir` is not set.
        """
        tool_dir = args.benchmark_tool_dir
        if not tool_dir:
            raise ValueError("--benchmark_tool_dir should be specified.")

        commit_tmp_dir: pathlib.Path = (args.tmp_dir / git_commit_hash).resolve()

        # Anything shaped like "<scheme>://..." is wrapped as a URL location;
        # everything else is wrapped as a local path location.
        artifacts_dir = args.e2e_test_artifacts_dir
        looks_like_url = re.match("^[^:]+://", str(artifacts_dir)) is not None
        make_location = (
            benchmark_definition.ResourceLocation.build_url
            if looks_like_url
            else benchmark_definition.ResourceLocation.build_local_path
        )
        benchmark_root = make_location(artifacts_dir)

        results_dir = commit_tmp_dir / BENCHMARK_RESULTS_REL_PATH
        # The guard above ensures tool_dir is set, so it can be resolved
        # unconditionally.
        resolved_tool_dir = tool_dir.resolve()

        return BenchmarkConfig(
            tmp_dir=commit_tmp_dir,
            root_benchmark_dir=benchmark_root,
            benchmark_results_dir=results_dir,
            git_commit_hash=git_commit_hash,
            benchmark_tool_dir=resolved_tool_dir,
            driver_filter=args.driver_filter_regex,
            model_name_filter=args.model_name_regex,
            mode_filter=args.mode_regex,
            use_compatible_filter=args.compatible_only,
            keep_going=args.keep_going,
            benchmark_min_time=args.benchmark_min_time,
            continue_from_previous=args.continue_from_previous,
            verify=args.verify,
        )