|  | # Copyright 2022 The IREE Authors | 
|  | # | 
|  | # Licensed under the Apache License v2.0 with LLVM Exceptions. | 
|  | # See https://llvm.org/LICENSE.txt for license information. | 
|  | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  |  | 
|  | import abc | 
|  | import re | 
|  |  | 
|  | from typing import Optional | 
|  |  | 
|  |  | 
class BenchmarkCommand(abc.ABC):
    """Base description of a benchmark invocation.

    Holds the settings shared by every runtime (binary, model name, thread and
    run counts, optional driver and optional `taskset` prefix). Concrete
    subclasses name their runtime, add their own flags, and know how to pull
    a latency out of the binary's output.
    """

    def __init__(self,
                 benchmark_binary: str,
                 model_name: str,
                 num_threads: int,
                 num_runs: int,
                 driver: Optional[str] = None,
                 taskset: Optional[str] = None):
        """Stores the shared benchmark settings.

        Args:
          benchmark_binary: Executable placed at the start of the command
            (after the optional `taskset` prefix).
          model_name: Name of the benchmarked model.
          num_threads: Thread count; subclasses decide how to pass it.
          num_runs: Run count; subclasses decide how to pass it.
          driver: Optional driver/device selector interpreted by subclasses.
          taskset: When truthy, the command is prefixed with
            `taskset <taskset>` (e.g. a CPU mask for the `taskset` utility).
        """
        self.benchmark_binary = benchmark_binary
        self.model_name = model_name
        self.num_threads = num_threads
        self.num_runs = num_runs
        self.driver = driver
        self.taskset = taskset
        # Arguments placed right after the binary; subclasses append to this.
        self.args = []

    @property
    @abc.abstractmethod
    def runtime(self):
        """Name of the runtime exercised by this command."""

    @abc.abstractmethod
    def parse_latency_from_output(self, output: str) -> float:
        """Extracts a latency value from the benchmark binary's output."""

    def generate_benchmark_command(self) -> list[str]:
        """Builds a fresh argv list: [taskset, mask,] binary, *args."""
        affinity_prefix = ["taskset", str(self.taskset)] if self.taskset else []
        return [*affinity_prefix, self.benchmark_binary, *self.args]
|  |  | 
|  |  | 
class TFLiteBenchmarkCommand(BenchmarkCommand):
    """Represents a TFLite benchmark command.

    The generated command passes the model via `--graph=`, the thread and run
    counts via `--num_threads=` / `--num_runs=`, and adds `--use_gpu=true`
    when `driver` is "gpu".
    """

    def __init__(self,
                 benchmark_binary: str,
                 model_name: str,
                 model_path: str,
                 num_threads: int,
                 num_runs: int,
                 taskset: Optional[str] = None,
                 driver: Optional[str] = None):
        """Initializes the command.

        Args:
          benchmark_binary: TFLite benchmark executable to run.
          model_name: Name of the benchmarked model.
          model_path: Model file, passed to the binary as `--graph=`.
          num_threads: Value passed as `--num_threads=`.
          num_runs: Value passed as `--num_runs=`.
          taskset: If truthy, the command is prefixed with `taskset <taskset>`.
          driver: Only "gpu" has an effect (adds `--use_gpu=true`). It is the
            last parameter so existing positional callers keep working.
        """
        super().__init__(benchmark_binary,
                         model_name,
                         num_threads,
                         num_runs,
                         driver=driver,
                         taskset=taskset)
        self.args.append("--graph=" + model_path)
        # The average may be printed in scientific notation (e.g. 1.18859e+06)
        # or as a plain decimal (e.g. 71495.6). Dots are escaped so they only
        # match a literal decimal point, the exponent must carry digits, and
        # the plain pattern keeps the fractional part instead of truncating it.
        self._latency_large_regex = re.compile(
            r"Inference \(avg\): (\d+(?:\.\d*)?[eE][+-]?\d+)")
        self._latency_regex = re.compile(
            r"Inference \(avg\): (\d+(?:\.\d*)?)")

    @property
    def runtime(self):
        return "tflite"

    def parse_latency_from_output(self, output: str) -> float:
        """Returns the "Inference (avg)" value from `output` in milliseconds.

        The reported number is divided by 1000 to convert it to milliseconds.
        Prints a warning and returns 0.0 when no latency line is found.
        """
        # Scientific notation must be tried first: the plain pattern would
        # otherwise match only the mantissa of e.g. 1.18859e+06.
        matches = (self._latency_large_regex.search(output) or
                   self._latency_regex.search(output))
        if not matches:
            print("Warning! Could not parse latency. Defaulting to 0ms.")
            return 0.0
        return float(matches.group(1)) / 1000

    def generate_benchmark_command(self) -> list[str]:
        """Returns the base command plus the TFLite-specific flags."""
        command = super().generate_benchmark_command()
        if self.driver == "gpu":
            command.append("--use_gpu=true")
        command.append("--num_threads=" + str(self.num_threads))
        command.append("--num_runs=" + str(self.num_runs))
        return command
|  |  | 
|  |  | 
class IreeBenchmarkCommand(BenchmarkCommand):
    """Represents an IREE benchmark command.

    The generated command passes the module via `--module=`, the driver via
    `--device=` (when one is configured), the thread count via
    `--task_topology_group_count=` and the run count via
    `--benchmark_repetitions=`.
    """

    def __init__(self,
                 benchmark_binary: str,
                 model_name: str,
                 model_path: str,
                 num_threads: int,
                 num_runs: int,
                 taskset: Optional[str] = None,
                 driver: Optional[str] = None):
        """Initializes the command.

        Args:
          benchmark_binary: IREE benchmark executable to run.
          model_name: Name of the benchmarked model.
          model_path: Compiled module, passed to the binary as `--module=`.
          num_threads: Value passed as `--task_topology_group_count=`.
          num_runs: Value passed as `--benchmark_repetitions=`.
          taskset: If truthy, the command is prefixed with `taskset <taskset>`.
          driver: Value passed as `--device=`. It is the last parameter so
            existing positional callers keep working; callers may also
            assign `self.driver` after construction.
        """
        super().__init__(benchmark_binary,
                         model_name,
                         num_threads,
                         num_runs,
                         driver=driver,
                         taskset=taskset)
        self.args.append("--module=" + model_path)
        self._latency_regex = re.compile(
            r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*")

    @property
    def runtime(self):
        return "iree"

    def parse_latency_from_output(self, output: str) -> float:
        """Returns the `real_time_mean` value (reported in ms) from `output`.

        Prints a warning and returns 0.0 when no matching line is found.
        """
        matches = self._latency_regex.search(output)
        if not matches:
            print("Warning! Could not parse latency. Defaulting to 0ms.")
            return 0.0
        return float(matches.group(1))

    def generate_benchmark_command(self) -> list[str]:
        """Returns the base command plus the IREE-specific flags."""
        command = super().generate_benchmark_command()
        # `driver` defaults to None; concatenating None would raise TypeError,
        # so the flag is only emitted when a driver has been configured.
        if self.driver is not None:
            command.append("--device=" + self.driver)
        command.append("--task_topology_group_count=" + str(self.num_threads))
        command.append("--benchmark_repetitions=" + str(self.num_runs))
        return command