blob: 3bc5db30da90e19116e4c10394d19a32a4474e3f [file] [log] [blame]
mariecwhitec3bfbe22022-05-31 11:49:48 -07001# Copyright 2022 The IREE Authors
2#
3# Licensed under the Apache License v2.0 with LLVM Exceptions.
4# See https://llvm.org/LICENSE.txt for license information.
5# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
import re
import subprocess
import tempfile
import time

from common.benchmark_command import BenchmarkCommand
12
# Regexes for retrieving memory information from `/proc/<pid>/status`.
_VMHWM_REGEX = re.compile(r".*?VmHWM:.*?(\d+) kB.*")
_VMRSS_REGEX = re.compile(r".*?VmRSS:.*?(\d+) kB.*")
_RSSFILE_REGEX = re.compile(r".*?RssFile:.*?(\d+) kB.*")


def _read_memory_stats(pid: int) -> "tuple[float, float, float] | None":
    """Samples (VmHWM, VmRSS, RssFile), in kB, for `pid` from procfs.

    Returns None when the status file cannot be read (e.g. no procfs on this
    host) or does not contain all three fields, so the caller can simply skip
    that sample.
    """
    try:
        # Read the file directly instead of spawning `cat` on every sample.
        with open(
            f"/proc/{pid}/status", encoding="utf-8", errors="replace"
        ) as status_file:
            status = status_file.read()
    except OSError:
        return None
    matches = [
        regex.search(status)
        for regex in (_VMHWM_REGEX, _VMRSS_REGEX, _RSSFILE_REGEX)
    ]
    if not all(matches):
        return None
    vmhwm, vmrss, rssfile = (float(match.group(1)) for match in matches)
    return vmhwm, vmrss, rssfile


def run_command(
    benchmark_command: "BenchmarkCommand", poll_interval_s: float = 0.5
) -> list[float]:
    """Runs `benchmark_command` and polls for memory consumption statistics.

    Args:
      benchmark_command: A `BenchmarkCommand` object containing information on
        how to run the benchmark and parse the output.
      poll_interval_s: Seconds to sleep between memory samples while the
        benchmark is running.

    Returns:
      An array containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`].
      The memory values are in kB and stay 0 if no sample could be taken.
      Returns [0, 0, 0, 0] if the command exits with a non-zero code.
    """
    command = benchmark_command.generate_benchmark_command()
    print("\n\nRunning command:\n" + " ".join(command))

    # Keep a record of the highest VmHWM and the VmRSS/RssFile values sampled
    # at the same time.
    vmhwm = vmrss = rssfile = 0.0

    # Send the child's output to an unnamed temporary file rather than a pipe.
    # Nothing reads a pipe while we poll, so a benchmark printing more than the
    # OS pipe buffer would block on write forever and this loop would never end.
    with tempfile.TemporaryFile() as output_file:
        benchmark_process = subprocess.Popen(
            command, stdout=output_file, stderr=subprocess.STDOUT
        )
        try:
            while benchmark_process.poll() is None:
                stats = _read_memory_stats(benchmark_process.pid)
                if stats is not None and stats[0] > vmhwm:
                    vmhwm, vmrss, rssfile = stats
                time.sleep(poll_interval_s)
        finally:
            # Don't leave an orphaned benchmark running if polling is
            # interrupted (e.g. KeyboardInterrupt).
            if benchmark_process.poll() is None:
                benchmark_process.kill()
                benchmark_process.wait()
        output_file.seek(0)
        output = output_file.read().decode(errors="replace")

    if benchmark_process.returncode != 0:
        print(
            f"Warning! Benchmark command failed with return code:"
            f" {benchmark_process.returncode}"
        )
        return [0, 0, 0, 0]

    print(output)
    latency_ms = benchmark_command.parse_latency_from_output(output)
    return [latency_ms, vmhwm, vmrss, rssfile]