# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
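"""Run the jax testsuite hermetically under the OpenXLA PJRT plugin.

Collects the individual tests from the given test files, runs each one as
its own pytest process in parallel, and writes/compares pass and fail
lists. Example invocation:

    python test_jax.py -j 16 -e expected.txt <testfiles...>
"""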
import argparse
import multiprocessing
import os
import random
import re
import subprocess
import sys
import time
from collections import namedtuple
from enum import Enum

parser = argparse.ArgumentParser(prog='test_jax.py',
                                 description='Run jax testsuite hermetically')
parser.add_argument('testfiles', nargs='*', help='jax test files to run')
parser.add_argument('-t', '--timeout', type=float, default=60,
                    help='per-test timeout in seconds (<= 0 disables it)')
parser.add_argument('-l', '--logdir', default='/tmp/jaxtest',
                    help='directory for PJRT artifacts')
parser.add_argument('-p', '--passing', default=None,
                    help='file to write the list of passing tests to')
parser.add_argument('-f', '--failing', default=None,
                    help='file to write the list of failing tests to')
parser.add_argument('-e', '--expected', default=None,
                    help='file holding the expected passing tests')
parser.add_argument('-j', '--jobs', type=int, default=os.cpu_count(),
                    help='number of parallel test processes')
args = parser.parse_args()
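
# Base pytest command: every test runs in its own pytest process through the
# openxla_pjrt_artifacts plugin, which writes its artifacts under --logdir.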
PYTEST_CMD = [
    "pytest", "-p", "openxla_pjrt_artifacts",
    f"--openxla-pjrt-artifact-dir={args.logdir}"
]

def get_test(test):
    """Collect the individual test ids (::Class::function) in one test file."""
    print("Fetching from:", test)
    result = subprocess.run(PYTEST_CMD + ["--setup-only", test],
                            capture_output=True)
    # Scrape the "::Class::function" ids out of the collection output;
    # exclude newlines so a match never spans lines.
    lst = re.findall('::[^ \n]*::[^ \n]*', result.stdout.decode())
    return [test + func for func in lst]

def get_tests(tests):
    """Collect test ids from every test file in parallel."""
    with multiprocessing.Pool(os.cpu_count()) as p:
        fulltestlist = p.map(get_test, tests)
    # Flatten the per-file lists into one sorted list of test ids.
    return sorted([i for lst in fulltestlist for i in lst])

def generate_test_commands(tests):
    """Build one standalone pytest command per test id."""
    cmds = []
    for test in tests:
        test_cmd = PYTEST_CMD + [test]
        cmds.append(test_cmd)
    return cmds
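
# A TestCase pairs one pytest command with its timeout; TestResult is the
# three-way outcome recorded for each run.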
TestCase = namedtuple('TestCase', ['test', 'timeout'])
TestResult = Enum('TestResult', ['SUCCESS', 'FAILURE', 'TIMEOUT'])

def exec_test(testcase):
    """Run one test command, returning SUCCESS, FAILURE, or TIMEOUT."""
    command, timeout = testcase
    if float(timeout) > 0:
        # Bound the run with the coreutils `timeout` command.
        command = ["timeout", f"{timeout}"] + command
    start = time.perf_counter()
    result = subprocess.run(command, capture_output=True)
    end = time.perf_counter()
    elapsed = end - start
    timedout = (float(timeout) > 0) and (elapsed > float(timeout))
    # Emit a one-character progress marker per test.
    if result.returncode == 0:
        sys.stdout.write(".")
        sys.stdout.flush()
        return TestResult.SUCCESS
    if timedout:
        sys.stdout.write("t")
        sys.stdout.flush()
        return TestResult.TIMEOUT
    sys.stdout.write("f")
    sys.stdout.flush()
    return TestResult.FAILURE

def exec_testsuite(commands, jobs, timeout):
    """Run all test commands across `jobs` processes and bucket the results."""
    # Shuffle so long-running tests spread evenly across the workers.
    random.shuffle(commands)
    with_timeout = [TestCase(cmd, timeout) for cmd in commands]
    with multiprocessing.Pool(int(jobs)) as p:
        results = p.map(exec_test, with_timeout)
    passing, timedout, failing = [], [], []
    for result, cmd in zip(results, commands):
        if result == TestResult.SUCCESS:
            passing.append(cmd)
        if result == TestResult.TIMEOUT:
            timedout.append(cmd)
        if result == TestResult.FAILURE:
            failing.append(cmd)
    print("")
    return passing, timedout, failing

def get_testnames(cmds):
    """Strip each command back down to its file::class::function test name."""
    names = []
    for cmd in cmds:
        testname = " ".join(cmd)
        testname = re.search("[^ /]*::[^ ]*::[^ ]*", testname)[0]
        names.append(testname)
    return names

def write_results(filename, results):
    """Write one test name per line to filename, if one was given."""
    if filename is not None:
        with open(filename, 'w') as f:
            for line in results:
                f.write(line + "\n")

def load_results(filename):
    """Load a list of test names from filename; empty if unset or missing."""
    if not filename or not os.path.isfile(filename):
        return []
    expected = []
    with open(filename, 'r') as f:
        for line in f:
            expected.append(line.strip())
    return expected

def compare_results(expected, passing):
    """Diff the observed passing set against the expected passing set."""
    passing = set(passing)
    expected = set(expected)
    new_failures = expected - passing
    new_passing = passing - expected
    return new_passing, new_failures
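
# Driver: collect the tests, run them in parallel, write the results out,
# and compare against the expected passing list.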
print("Querying All Tests")
tests = get_tests(args.testfiles)
print("Generating test suite")
commands = generate_test_commands(tests)
print(f"Executing {len(commands)} tests across {args.jobs} threads with timeout = {args.timeout}")
passing, timeout, failing = exec_testsuite(commands, jobs=args.jobs, timeout=args.timeout)
expected = load_results(args.expected)
# Break into passing vs failing
failing = failing + timeout
# Get the testnames
passing = get_testnames(passing)
failing = get_testnames(failing)
write_results(args.passing, passing)
write_results(args.failing, failing)
print("Total:", len(commands))
print("Passing:", len(passing))
print("Failing:", len(failing))
print("Failing (timed out):", len(timeout))
if expected:
new_passing, new_failures = compare_results(expected, passing)
if new_passing:
print("Newly Passing Tests:")
for test in new_passing:
print(" ", test)
if new_failures:
print("Newly Failing Tests:")
for test in new_failures:
print(" ", test)
if len(expected) > len(passing):
exit(1)