pw_doctor: Run checks in parallel
Run doctor checks in parallel and allow checks to submit their own
parallel jobs. This substantially speeds up pw doctor.
Change-Id: Ib09919022113daadf4d605f37c3c715f325ded13
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/13101
Reviewed-by: Rob Mohr <mohrr@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_doctor/py/pw_doctor/doctor.py b/pw_doctor/py/pw_doctor/doctor.py
index b013636..1402047 100755
--- a/pw_doctor/py/pw_doctor/doctor.py
+++ b/pw_doctor/py/pw_doctor/doctor.py
@@ -15,6 +15,7 @@
"""Checks if the environment is set up correctly for Pigweed."""
import argparse
+from concurrent import futures
import logging
import json
import os
@@ -23,7 +24,7 @@
import subprocess
import sys
import tempfile
-from typing import Callable, List
+from typing import Callable, Iterable, List, Set
def call_stdout(*args, **kwargs):
@@ -36,35 +37,82 @@
pass
+class Doctor:
+ def __init__(self, *, log: logging.Logger = None, strict: bool = False):
+ self.strict = strict
+ self.log = log or logging.getLogger(__name__)
+ self.failures: Set[str] = set()
+
+ def run(self, checks: Iterable[Callable]):
+ with futures.ThreadPoolExecutor() as executor:
+ futures.wait([
+ executor.submit(self._run_check, c, executor) for c in checks
+ ])
+
+ def _run_check(self, check, executor):
+ ctx = DoctorContext(self, check.__name__, executor)
+ try:
+ self.log.debug('Running check %s', ctx.check)
+ check(ctx)
+ ctx.wait()
+ except _Fatal:
+ pass
+ except: # pylint: disable=bare-except
+ self.failures.add(ctx.check)
+ self.log.exception('%s failed with an unexpected exception',
+ check.__name__)
+
+ self.log.debug('Completed check %s', ctx.check)
+
+
class DoctorContext:
- """Base class for other checks."""
- def __init__(self, *, strict=False, log=None):
- self.name = self.__class__.__name__
- self._strict = strict
- self._log = log or logging.getLogger(__name__)
- self.failures = set()
- self.curr_checker = None
+ """The context object provided to each context function."""
+ def __init__(self, doctor: Doctor, check: str, executor: futures.Executor):
+ self._doctor = doctor
+ self.check = check
+ self._executor = executor
+ self._futures: List[futures.Future] = []
+
+ def submit(self, function, *args, **kwargs):
+ """Starts running the provided function in parallel."""
+ self._futures.append(
+ self._executor.submit(self._run_job, function, *args, **kwargs))
+
+ def wait(self):
+ """Waits for all parallel tasks started with submit() to complete."""
+ futures.wait(self._futures)
+ self._futures.clear()
+
+ def _run_job(self, function, *args, **kwargs):
+ try:
+ function(*args, **kwargs)
+ except _Fatal:
+ pass
+ except: # pylint: disable=bare-except
+ self._doctor.failures.add(self.check)
+ self._doctor.log.exception(
+ '%s failed with an unexpected exception', self.check)
def fatal(self, fmt, *args, **kwargs):
- """Same as error() but terminates the checkearly."""
+ """Same as error() but terminates the check early."""
self.error(fmt, *args, **kwargs)
raise _Fatal()
def error(self, fmt, *args, **kwargs):
- self._log.error(fmt, *args, **kwargs)
- self.failures.add(self.curr_checker)
+ self._doctor.log.error(fmt, *args, **kwargs)
+ self._doctor.failures.add(self.check)
def warning(self, fmt, *args, **kwargs):
- if self._strict:
+ if self._doctor.strict:
self.error(fmt, *args, **kwargs)
else:
- self._log.warning(fmt, *args, **kwargs)
+ self._doctor.log.warning(fmt, *args, **kwargs)
def info(self, fmt, *args, **kwargs):
- self._log.info(fmt, *args, **kwargs)
+ self._doctor.log.info(fmt, *args, **kwargs)
def debug(self, fmt, *args, **kwargs):
- self._log.debug(fmt, *args, **kwargs)
+ self._doctor.log.debug(fmt, *args, **kwargs)
def register_into(dest):
@@ -200,10 +248,7 @@
json_path = root.joinpath('pw_env_setup', 'py', 'pw_env_setup',
'cipd_setup', 'pigweed.json')
- with json_path.open() as ins:
- packages = json.load(ins)
-
- for package in packages:
+ def check_cipd(package):
ctx.debug('checking version of %s', package['path'])
name = [
part for part in package['path'].split('/') if '{' not in part
@@ -211,7 +256,7 @@
path = versions_path.joinpath(f'{name}.cipd_version')
if not path.is_file():
ctx.debug('no version file')
- continue
+ return
with path.open() as ins:
installed = json.load(ins)
@@ -234,33 +279,26 @@
'CIPD package %s is out of date, please rerun bootstrap',
installed['package_name'])
+ for package in json.loads(json_path.read_text()):
+ ctx.submit(check_cipd, package)
-def doctor(strict=False, checks=None):
+
+def run_doctor(strict=False, checks=None):
"""Run all the Check subclasses defined in this file."""
- ctx = DoctorContext(strict=strict)
-
if checks is None:
checks = tuple(CHECKS)
- ctx.debug('Doctor running %d checks...', len(checks))
- for check in checks:
- try:
- ctx.debug('Running %s...', check.__name__)
- ctx.curr_checker = check.__name__
- check(ctx)
+ doctor = Doctor(strict=strict)
+ doctor.log.debug('Doctor running %d checks...', len(checks))
- except _Fatal:
- pass
+ doctor.run(checks)
- finally:
- ctx.curr_checker = None
-
- if ctx.failures:
- ctx.info('Failed checks: %s', ', '.join(ctx.failures))
+ if doctor.failures:
+ doctor.log.info('Failed checks: %s', ', '.join(doctor.failures))
else:
- ctx.info('Environment passes all checks!')
- return len(ctx.failures)
+ doctor.log.info('Environment passes all checks!')
+ return len(doctor.failures)
def main() -> int:
@@ -272,7 +310,7 @@
help='Run additional checks.',
)
- return doctor(**vars(parser.parse_args()))
+ return run_doctor(**vars(parser.parse_args()))
if __name__ == '__main__':