# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
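"""Runs the JAX test suite hermetically against the OpenXLA PJRT plugin.

Each collected test is launched in its own pytest process so that a crash or
hang in one test cannot take down the rest of the run. Passing and failing
test names can be written out and compared against an expected list.

Example invocation (paths and flag values are illustrative):

    python test_jax.py -j 8 -t 120 /path/to/jax/tests/lax_numpy_test.py
"""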

import argparse
import multiprocessing
import os
import random
import re
import subprocess
import sys
import time

from collections import namedtuple
from enum import Enum

parser = argparse.ArgumentParser(
    prog="test_jax.py", description="Run the JAX test suite hermetically"
)
parser.add_argument("testfiles", nargs="*", help="test files to collect tests from")
parser.add_argument(
    "-t", "--timeout", type=float, default=60, help="per-test timeout in seconds"
)
parser.add_argument("-l", "--logdir", default="/tmp/jaxtest", help="artifact directory")
parser.add_argument("-p", "--passing", default=None, help="output for passing tests")
parser.add_argument("-f", "--failing", default=None, help="output for failing tests")
parser.add_argument("-e", "--expected", default=None, help="tests expected to pass")
parser.add_argument(
    "-j", "--jobs", type=int, default=os.cpu_count(), help="parallel test processes"
)

args = parser.parse_args()

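# Base pytest invocation: load the openxla_pjrt_artifacts plugin and direct its
# per-test artifacts into the log directory.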
PYTEST_CMD = [
    "pytest",
    "-p",
    "openxla_pjrt_artifacts",
    f"--openxla-pjrt-artifact-dir={args.logdir}",
]


def get_test(test):
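    """Collects the individual test IDs defined in a single test file.

    Runs `pytest --setup-only` on the file and scrapes the reported
    file::class::function node IDs from its output.
    """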
| print("Fetching from:", test) |
| stdout = subprocess.run(PYTEST_CMD + ["--setup-only", test], capture_output=True) |
| lst = re.findall("::[^ ]*::[^ ]*", str(stdout)) |
| return [test + func for func in lst] |
| |
| |
def get_tests(tests):
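    """Collects test IDs from every test file in parallel."""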
    with multiprocessing.Pool(os.cpu_count()) as p:
        fulltestlist = p.map(get_test, tests)
    # Flatten the per-file lists into a single sorted list of test IDs.
    return sorted(test for lst in fulltestlist for test in lst)


def generate_test_commands(tests):
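    """Builds one standalone pytest command per test ID."""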
    cmds = []
    for test in tests:
        test_cmd = PYTEST_CMD + [test]
        cmds.append(test_cmd)

    return cmds


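# A TestCase pairs a pytest command with its timeout in seconds; TestResult is
# the per-test outcome reported by exec_test.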
TestCase = namedtuple("TestCase", ["test", "timeout"])
TestResult = Enum("TestResult", ["SUCCESS", "FAILURE", "TIMEOUT"])


def exec_test(testcase):
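    """Runs a single test command and reports SUCCESS, FAILURE, or TIMEOUT.

    Prints a one-character progress marker: '.' for success, 't' for a
    timeout, and 'f' for a failure.
    """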
    command, timeout = testcase
    # A positive timeout wraps the test in coreutils' timeout(1), which kills
    # the pytest process once the limit is exceeded.
    if float(timeout) > 0:
        command = ["timeout", f"{timeout}"] + command

    start = time.perf_counter()
    result = subprocess.run(command, capture_output=True)
    end = time.perf_counter()
    elapsed = end - start
    timedout = (float(timeout) > 0) and (elapsed > float(timeout))

    if result.returncode == 0:
        sys.stdout.write(".")
        sys.stdout.flush()
        return TestResult.SUCCESS

    if timedout:
        sys.stdout.write("t")
        sys.stdout.flush()
        return TestResult.TIMEOUT

    sys.stdout.write("f")
    sys.stdout.flush()
    return TestResult.FAILURE


def exec_testsuite(commands, jobs, timeout):
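    """Runs all test commands across a pool of worker processes.

    Returns three lists of commands: passing, timed out, and failing.
    """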
    # Randomize execution order across the worker pool.
    random.shuffle(commands)
    testcases = [TestCase(cmd, timeout) for cmd in commands]

    with multiprocessing.Pool(int(jobs)) as p:
        results = p.map(exec_test, testcases)

    passing, timed_out, failing = [], [], []
    for result, cmd in zip(results, commands):
        if result == TestResult.SUCCESS:
            passing.append(cmd)
        elif result == TestResult.TIMEOUT:
            timed_out.append(cmd)
        elif result == TestResult.FAILURE:
            failing.append(cmd)
    print("")

    return passing, timed_out, failing


def get_testnames(cmd):
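    """Extracts the file::class::function test name from each command."""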
    names = []
    for c in cmd:
        testname = " ".join(c)
        testname = re.search(r"[^ /]*::[^ ]*::[^ ]*", testname)[0]
        names.append(testname)
    return names


def write_results(filename, results):
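    """Writes one test name per line to filename, if a filename was given."""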
    if filename is not None:
        with open(filename, "w") as f:
            for line in results:
                f.write(line + "\n")


def load_results(filename):
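    """Loads a list of test names from filename, one per line.

    Returns an empty list if no filename was given or the file does not exist.
    """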
    if not filename or not os.path.isfile(filename):
        return []
    expected = []
    with open(filename, "r") as f:
        for line in f:
            expected.append(line.strip())
    return expected


def compare_results(expected, passing):
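    """Compares the current passing set against the expected passing set.

    Returns (newly passing tests, newly failing tests).
    """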
    passing = set(passing)
    expected = set(expected)
    new_failures = expected - passing
    new_passing = passing - expected
    return new_passing, new_failures


| print("Querying All Tests") |
| tests = get_tests(args.testfiles) |
| |
| print("Generating test suite") |
| commands = generate_test_commands(tests) |
| |
| print( |
| f"Executing {len(commands)} tests across {args.jobs} threads with timeout = {args.timeout}" |
| ) |
| passing, timeout, failing = exec_testsuite( |
| commands, jobs=args.jobs, timeout=args.timeout |
| ) |
| |
expected = load_results(args.expected)

# Timed-out tests are counted as failures.
failing = failing + timeout

# Reduce the full commands to plain test names.
passing = get_testnames(passing)
failing = get_testnames(failing)

write_results(args.passing, passing)
write_results(args.failing, failing)

print("Total:", len(commands))
print("Passing:", len(passing))
print("Failing:", len(failing))
print("Failing (timed out):", len(timeout))

if expected:
    new_passing, new_failures = compare_results(expected, passing)

    if new_passing:
        print("Newly Passing Tests:")
        for test in new_passing:
            print(" ", test)

    if new_failures:
        print("Newly Failing Tests:")
        for test in new_failures:
            print(" ", test)

# Fail the run if fewer tests passed than the expected list requires.
if len(expected) > len(passing):
    sys.exit(1)