# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
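
"""Abstractions for building TFLite and IREE benchmark commands and for
parsing the average latency from each benchmark tool's output."""
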
import abc
import re
from typing import Optional


class BenchmarkCommand(abc.ABC):
"""Abstracts a benchmark command."""
def __init__(
self,
benchmark_binary: str,
model_name: str,
num_threads: int,
num_runs: int,
driver: Optional[str] = None,
taskset: Optional[str] = None,
):
self.benchmark_binary = benchmark_binary
self.model_name = model_name
self.taskset = taskset
self.num_threads = num_threads
self.num_runs = num_runs
self.driver = driver
self.args = []

    @property
    @abc.abstractmethod
    def runtime(self):
        """Returns the runtime name, e.g. "tflite" or "iree"."""
        pass

    @abc.abstractmethod
    def parse_latency_from_output(self, output: str) -> float:
        """Parses the average latency in milliseconds from the tool output."""
        pass

def generate_benchmark_command(self) -> list[str]:
"""Returns a list of strings that correspond to the command to be run."""
command = []
if self.taskset:
command.append("taskset")
command.append(str(self.taskset))
command.append(self.benchmark_binary)
command.extend(self.args)
return command


class TFLiteBenchmarkCommand(BenchmarkCommand):
    """Represents a TFLite benchmark command."""

def __init__(
self,
benchmark_binary: str,
model_name: str,
model_path: str,
num_threads: int,
num_runs: int,
taskset: Optional[str] = None,
):
super().__init__(
benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
)
self.args.append("--graph=" + model_path)
        self._latency_large_regex = re.compile(
            r".*?Inference \(avg\): (\d+\.?\d*e\+?\d*).*"
        )
        self._latency_regex = re.compile(r".*?Inference \(avg\): (\d+\.?\d*).*")

@property
def runtime(self):
return "tflite"

    def parse_latency_from_output(self, output: str) -> float:
        # First try to match a latency reported in scientific notation,
        # e.g. 1.18859e+06.
        matches = self._latency_large_regex.search(output)
        if not matches:
            # Otherwise, match a plain number, e.g. 71495.6.
            matches = self._latency_regex.search(output)
        latency_ms = 0.0
        if matches:
            # The TFLite benchmark tool reports latency in microseconds;
            # convert to milliseconds.
            latency_ms = float(matches.group(1)) / 1000
        else:
            print("Warning! Could not parse latency. Defaulting to 0 ms.")
        return latency_ms

def generate_benchmark_command(self) -> list[str]:
command = super().generate_benchmark_command()
if self.driver == "gpu":
command.append("--use_gpu=true")
command.append("--num_threads=" + str(self.num_threads))
command.append("--num_runs=" + str(self.num_runs))
return command


class IreeBenchmarkCommand(BenchmarkCommand):
    """Represents an IREE benchmark command."""

def __init__(
self,
benchmark_binary: str,
model_name: str,
model_path: str,
num_threads: int,
num_runs: int,
taskset: Optional[str] = None,
):
super().__init__(
benchmark_binary, model_name, num_threads, num_runs, taskset=taskset
)
self.args.append("--module=" + model_path)
self._latency_regex = re.compile(
r".*?BM_main/process_time/real_time_mean\s+(.*?) ms.*"
)
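        # The regex above targets the benchmark tool's aggregated mean line,
        # e.g. (illustrative): "BM_main/process_time/real_time_mean  12.3 ms".
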
@property
def runtime(self):
return "iree"

    def parse_latency_from_output(self, output: str) -> float:
        matches = self._latency_regex.search(output)
        latency_ms = 0.0
        if matches:
            # The regex captures the mean latency, already in milliseconds.
            latency_ms = float(matches.group(1))
        else:
            print("Warning! Could not parse latency. Defaulting to 0 ms.")
        return latency_ms

def generate_benchmark_command(self) -> list[str]:
command = super().generate_benchmark_command()
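        # Note: self.driver is assumed to have been set by the caller (it
        # defaults to None in BenchmarkCommand) before generating the command.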
command.append("--device=" + self.driver)
command.append("--task_topology_max_group_count=" + str(self.num_threads))
command.append("--benchmark_repetitions=" + str(self.num_runs))
return command
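

# A minimal, hypothetical usage sketch. The binary and model paths below are
# placeholders, and "local-task" is assumed to be an acceptable value for the
# IREE --device flag on the target machine.
if __name__ == "__main__":
    import subprocess

    tflite_cmd = TFLiteBenchmarkCommand(
        benchmark_binary="/data/local/tmp/benchmark_model",  # hypothetical path
        model_name="mobilenet_v2",
        model_path="/data/local/tmp/mobilenet_v2.tflite",  # hypothetical path
        num_threads=4,
        num_runs=10,
        taskset="80",  # optional CPU affinity mask forwarded to `taskset`
    )
    tflite_output = subprocess.run(
        tflite_cmd.generate_benchmark_command(),
        capture_output=True,
        text=True,
        check=False,
    ).stdout
    print(
        f"{tflite_cmd.model_name} ({tflite_cmd.runtime}): "
        f"{tflite_cmd.parse_latency_from_output(tflite_output)} ms"
    )

    iree_cmd = IreeBenchmarkCommand(
        benchmark_binary="/data/local/tmp/iree-benchmark-module",  # hypothetical path
        model_name="mobilenet_v2",
        model_path="/data/local/tmp/mobilenet_v2.vmfb",  # hypothetical path
        num_threads=4,
        num_runs=10,
        taskset="80",
    )
    iree_cmd.driver = "local-task"  # must be set before generating the command
    iree_output = subprocess.run(
        iree_cmd.generate_benchmark_command(),
        capture_output=True,
        text=True,
        check=False,
    ).stdout
    print(
        f"{iree_cmd.model_name} ({iree_cmd.runtime}): "
        f"{iree_cmd.parse_latency_from_output(iree_output)} ms"
    )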