# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
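# Runs the JAX test suite under the OpenXLA PJRT plugin, launching one pytest
# process per test with a per-test timeout. A typical invocation (the paths
# here are illustrative, not required by the script) might look like:
#
#   python test_jax.py tests/*_test.py -j 8 -t 120 \
#       -p passing.txt -f failing.txt -e expected.txt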
import argparse
import multiprocessing
import os
import random
import re
import subprocess
import sys
import time
from collections import namedtuple
from enum import Enum
parser = argparse.ArgumentParser(
    prog="test_jax.py", description="Run jax testsuite hermetically"
)
parser.add_argument("testfiles", nargs="*")
parser.add_argument("-t", "--timeout", type=float, default=60)
parser.add_argument("-l", "--logdir", default="/tmp/jaxtest")
parser.add_argument("-p", "--passing", default=None)
parser.add_argument("-f", "--failing", default=None)
parser.add_argument("-e", "--expected", default=None)
# Default to one worker per core; a bare None here would crash int(jobs) below.
parser.add_argument("-j", "--jobs", type=int, default=os.cpu_count())
args = parser.parse_args()
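
# Base pytest invocation: load the openxla_pjrt_artifacts plugin and write
# per-test artifacts under the chosen log directory.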
PYTEST_CMD = [
    "pytest",
    "-p",
    "openxla_pjrt_artifacts",
    f"--openxla-pjrt-artifact-dir={args.logdir}",
]
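
# Collect the individual test ids in one test file: "pytest --setup-only"
# reports each "file::Class::test" id without executing the test bodies.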
def get_test(test):
    print("Fetching from:", test)
    result = subprocess.run(PYTEST_CMD + ["--setup-only", test], capture_output=True)
    # Match on the decoded output rather than the CompletedProcess repr, and
    # stop ids at any whitespace so trailing newlines are not captured.
    lst = re.findall(r"::\S+::\S+", result.stdout.decode())
    return [test + func for func in lst]
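
# Collect ids from every test file in parallel, one worker per core, then
# flatten the per-file lists into a single sorted list.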
def get_tests(tests):
    fulltestlist = []
    with multiprocessing.Pool(os.cpu_count()) as p:
        fulltestlist = p.map(get_test, tests)
    fulltestlist = sorted([i for lst in fulltestlist for i in lst])
    return fulltestlist
def generate_test_commands(tests):
    cmds = []
    for test in tests:
        test_cmd = PYTEST_CMD + [test]
        cmds.append(test_cmd)
    return cmds
TestCase = namedtuple("TestCase", ["test", "timeout"])
TestResult = Enum("TestResult", ["SUCCESS", "FAILURE", "TIMEOUT"])
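
# Run one test command, emitting a single progress character per test:
# "." for a pass, "t" for a timeout, "f" for any other failure.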
def exec_test(testcase):
    command, timeout = testcase
    if float(timeout) > 0:
        # Wrap the test in GNU timeout so a hung test cannot stall the run.
        command = ["timeout", f"{timeout}"] + command
    start = time.perf_counter()
    result = subprocess.run(command, capture_output=True)
    end = time.perf_counter()
    elapsed = end - start
    timedout = (float(timeout) > 0) and (elapsed > float(timeout))
    if result.returncode == 0:
        sys.stdout.write(".")
        sys.stdout.flush()
        return TestResult.SUCCESS
    if timedout:
        sys.stdout.write("t")
        sys.stdout.flush()
        return TestResult.TIMEOUT
    sys.stdout.write("f")
    sys.stdout.flush()
    return TestResult.FAILURE
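
# Run the whole suite across a pool of `jobs` workers and bucket the commands
# by outcome. The command list is shuffled first, which spreads long-running
# tests across workers rather than clustering them on one.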
def exec_testsuite(commands, jobs, timeout):
    random.shuffle(commands)
    with_timeout = [TestCase(cmd, timeout) for cmd in commands]
    results = []
    with multiprocessing.Pool(int(jobs)) as p:
        results = p.map(exec_test, with_timeout)
    passing, timed_out, failing = [], [], []
    for result, cmd in zip(results, commands):
        if result == TestResult.SUCCESS:
            passing.append(cmd)
        elif result == TestResult.TIMEOUT:
            timed_out.append(cmd)
        elif result == TestResult.FAILURE:
            failing.append(cmd)
    print("")
    return passing, timed_out, failing
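
# Map each pytest command back to its bare "file::Class::test" id so results
# can be written to, and compared against, plain text lists.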
def get_testnames(cmd):
    names = []
    for c in cmd:
        testname = " ".join(c)
        testname = re.search("[^ /]*::[^ ]*::[^ ]*", testname)[0]
        names.append(testname)
    return names
def write_results(filename, results):
    if filename is not None:
        with open(filename, "w") as f:
            for line in results:
                f.write(line + "\n")
def load_results(filename):
    if not filename or not os.path.isfile(filename):
        return []
    expected = []
    with open(filename, "r") as f:
        for line in f:
            expected.append(line.strip())
    return expected
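
# Diff the current passing set against the expected baseline: baseline tests
# that no longer pass are regressions, and the converse are new passes.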
def compare_results(expected, passing):
    passing = set(passing)
    expected = set(expected)
    new_failures = expected - passing
    new_passing = passing - expected
    return new_passing, new_failures
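
# Main driver: collect the tests, run them, then write and summarize results.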
print("Querying All Tests")
tests = get_tests(args.testfiles)
print("Generating test suite")
commands = generate_test_commands(tests)
print(
    f"Executing {len(commands)} tests across {args.jobs} threads with timeout = {args.timeout}"
)
passing, timeout, failing = exec_testsuite(
    commands, jobs=args.jobs, timeout=args.timeout
)
expected = load_results(args.expected)
# Timed-out tests count as failures for reporting purposes.
failing = failing + timeout
# Get the testnames
passing = get_testnames(passing)
failing = get_testnames(failing)
write_results(args.passing, passing)
write_results(args.failing, failing)
print("Total:", len(commands))
print("Passing:", len(passing))
print("Failing:", len(failing))
print("Failing (timed out):", len(timeout))
if expected:
    new_passing, new_failures = compare_results(expected, passing)
    if new_passing:
        print("Newly Passing Tests:")
        for test in new_passing:
            print(" ", test)
    if new_failures:
        print("Newly Failing Tests:")
        for test in new_failures:
            print(" ", test)
    # Fail the run when fewer tests pass than the expected baseline.
    if len(expected) > len(passing):
        sys.exit(1)