# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|  |  | 
import re
import subprocess
import tempfile
import time

from common.benchmark_command import BenchmarkCommand
|  |  | 
# Regexes for retrieving memory information from `/proc/<pid>/status`.
# Each field is matched only at the start of a line (re.MULTILINE) as
# "<Field>:<whitespace><digits> kB". This avoids picking up the field name
# elsewhere in the file, e.g. inside the process-controlled `Name:` line.
# Group 1 is the numeric value in kB.
_VMHWM_REGEX = re.compile(r"^VmHWM:\s*(\d+) kB", re.MULTILINE)
_VMRSS_REGEX = re.compile(r"^VmRSS:\s*(\d+) kB", re.MULTILINE)
_RSSFILE_REGEX = re.compile(r"^RssFile:\s*(\d+) kB", re.MULTILINE)
|  |  | 
|  |  | 
def _read_proc_status(pid: int) -> str:
    """Returns the text of `/proc/<pid>/status`, or "" if it cannot be read.

    Args:
      pid: Id of the process whose status file should be read.

    Returns:
      The decoded file contents, with undecodable bytes replaced instead of
      raising. Returns an empty string if the file cannot be opened or read,
      e.g. because the process has already been reaped.
    """
    try:
        with open(f"/proc/{pid}/status", "rb") as status_file:
            return status_file.read().decode(errors="replace")
    except OSError:
        return ""


def run_command(benchmark_command: BenchmarkCommand,
                poll_interval_s: float = 0.5) -> list[float]:
    """Runs `benchmark_command` and polls for memory consumption statistics.

    While the benchmark runs, `/proc/<pid>/status` is sampled once every
    `poll_interval_s` seconds. The reported VmRSS and RssFile values come
    from the same sample as the highest VmHWM observed. A sample is used only
    if all three fields are present in it.

    Args:
      benchmark_command: A `BenchmarkCommand` object containing information on
        how to run the benchmark and parse the output.
      poll_interval_s: Seconds to sleep between memory samples. Defaults to
        0.5.

    Returns:
      A list containing values for [`latency`, `vmhwm`, `vmrss`, `rssfile`].
      The memory values are the kB numbers read from the status file, or 0 if
      no complete sample was taken. If the command exits with a non-zero
      return code, [0, 0, 0, 0] is returned instead.
    """
    command = benchmark_command.generate_benchmark_command()
    print("\n\nRunning command:\n" + " ".join(command))

    # Keep a record of the highest VmHWM and the VmRSS and RssFile values
    # sampled together with it.
    vmhwm = 0
    vmrss = 0
    rssfile = 0

    # Capture output in a temporary file rather than a pipe. Nothing drains
    # the output while the polling loop runs. With `stdout=PIPE`, a benchmark
    # that writes more than the OS pipe buffer would block on write and never
    # exit, and this loop would spin forever.
    with tempfile.TemporaryFile() as output_file:
        benchmark_process = subprocess.Popen(command,
                                             stdout=output_file,
                                             stderr=subprocess.STDOUT)
        try:
            while benchmark_process.poll() is None:
                status = _read_proc_status(benchmark_process.pid)
                vmhwm_match = _VMHWM_REGEX.search(status)
                vmrss_match = _VMRSS_REGEX.search(status)
                rssfile_match = _RSSFILE_REGEX.search(status)

                if vmhwm_match and vmrss_match and rssfile_match:
                    curr_vmhwm = float(vmhwm_match.group(1))
                    if curr_vmhwm > vmhwm:
                        vmhwm = curr_vmhwm
                        vmrss = float(vmrss_match.group(1))
                        rssfile = float(rssfile_match.group(1))

                time.sleep(poll_interval_s)
        finally:
            # If polling is interrupted (e.g. KeyboardInterrupt), do not leave
            # the benchmark running in the background.
            if benchmark_process.poll() is None:
                benchmark_process.kill()
                benchmark_process.wait()

        # The child wrote through its own descriptor; rewind before reading.
        output_file.seek(0)
        output = output_file.read().decode(errors="replace")

    # Print the merged stdout/stderr in both cases so failures can be
    # diagnosed.
    print(output)

    if benchmark_process.returncode != 0:
        print(f"Warning! Benchmark command failed with return code:"
              f" {benchmark_process.returncode}")
        return [0, 0, 0, 0]

    latency_ms = benchmark_command.parse_latency_from_output(output)
    return [latency_ms, vmhwm, vmrss, rssfile]