# blob: a41f326ba1205ff6070d6783e4a3d6e3323f57f2 [file] [log] [blame]
# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import re
import subprocess
import time
from common.benchmark_command import BenchmarkCommand
# Regexes for retrieving memory information from the text of
# `/proc/<pid>/status`. Each pattern captures, as group 1, the integer that
# precedes " kB" on the line carrying the named field.
# VmHWM: peak resident set size ("high water mark"), per proc(5).
_VMHWM_REGEX = re.compile(r".*?VmHWM:.*?(\d+) kB.*")
# VmRSS: current resident set size, per proc(5).
_VMRSS_REGEX = re.compile(r".*?VmRSS:.*?(\d+) kB.*")
# RssFile: resident file-backed memory, per proc(5).
_RSSFILE_REGEX = re.compile(r".*?RssFile:.*?(\d+) kB.*")
def _read_proc_status(pid: int) -> str:
    """Returns the text of `/proc/<pid>/status`, or "" if it cannot be read."""
    try:
        with open(f"/proc/{pid}/status", encoding="utf-8",
                  errors="replace") as status_file:
            return status_file.read()
    except OSError:
        # The process may already be gone, or procfs may be unavailable; either
        # way there is simply nothing to sample this round.
        return ""


def run_command(benchmark_command: BenchmarkCommand) -> list[float]:
    """Runs `benchmark_command` and polls for memory consumption statistics.

    While the benchmark process is alive, `/proc/<pid>/status` is sampled
    roughly every 0.5 seconds. The sample with the highest VmHWM is kept,
    together with the VmRSS and RssFile values reported in that same sample.

    Args:
      benchmark_command: A `BenchmarkCommand` object containing information on
        how to run the benchmark and parse the output.

    Returns:
      A list containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`].
      The memory values are the kB figures read from the status file and stay 0
      if no complete sample was obtained. `latency` is whatever
      `benchmark_command.parse_latency_from_output` returns for the captured
      output. If the command exits with a non-zero return code, a warning is
      printed and `[0, 0, 0, 0]` is returned.
    """
    command = benchmark_command.generate_benchmark_command()
    print("\n\nRunning command:\n" + " ".join(command))
    benchmark_process = subprocess.Popen(command,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.STDOUT)

    # Keep a record of the highest VmHWM and the VmRSS and RssFile values that
    # were reported alongside it.
    vmhwm = 0
    vmrss = 0
    rssfile = 0
    while True:
        status = _read_proc_status(benchmark_process.pid)
        vmhwm_match = _VMHWM_REGEX.search(status)
        vmrss_match = _VMRSS_REGEX.search(status)
        rssfile_match = _RSSFILE_REGEX.search(status)

        # Only accept a sample when all three fields are present so that the
        # recorded values always come from the same snapshot.
        if vmhwm_match and vmrss_match and rssfile_match:
            curr_vmhwm = float(vmhwm_match.group(1))
            if curr_vmhwm > vmhwm:
                vmhwm = curr_vmhwm
                vmrss = float(vmrss_match.group(1))
                rssfile = float(rssfile_match.group(1))

        # Wait up to 0.5s for the process to finish. Unlike poll() + sleep(),
        # communicate() keeps draining the stdout pipe, so a benchmark that
        # prints more than the OS pipe buffer cannot block forever. Retrying
        # after a timeout does not lose any output already read.
        try:
            stdout_data, _ = benchmark_process.communicate(timeout=0.5)
        except subprocess.TimeoutExpired:
            continue
        break

    if benchmark_process.returncode != 0:
        print(f"Warning! Benchmark command failed with return code:"
              f" {benchmark_process.returncode}")
        return [0, 0, 0, 0]

    output = stdout_data.decode()
    print(output)
    latency_ms = benchmark_command.parse_latency_from_output(output)
    return [latency_ms, vmhwm, vmrss, rssfile]