pw_doctor: Run checks in parallel

Run doctor checks in parallel and allow checks to submit their own
parallel jobs. This substantially speeds up pw doctor.

Change-Id: Ib09919022113daadf4d605f37c3c715f325ded13
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/13101
Reviewed-by: Rob Mohr <mohrr@google.com>
Reviewed-by: Keir Mierle <keir@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_doctor/py/pw_doctor/doctor.py b/pw_doctor/py/pw_doctor/doctor.py
index b013636..1402047 100755
--- a/pw_doctor/py/pw_doctor/doctor.py
+++ b/pw_doctor/py/pw_doctor/doctor.py
@@ -15,6 +15,7 @@
 """Checks if the environment is set up correctly for Pigweed."""
 
 import argparse
+from concurrent import futures
 import logging
 import json
 import os
@@ -23,7 +24,7 @@
 import subprocess
 import sys
 import tempfile
-from typing import Callable, List
+from typing import Callable, Iterable, List, Set
 
 
 def call_stdout(*args, **kwargs):
@@ -36,35 +37,82 @@
     pass
 
 
+class Doctor:
+    def __init__(self, *, log: logging.Logger = None, strict: bool = False):
+        self.strict = strict
+        self.log = log or logging.getLogger(__name__)
+        self.failures: Set[str] = set()
+
+    def run(self, checks: Iterable[Callable]):
+        with futures.ThreadPoolExecutor() as executor:
+            futures.wait([
+                executor.submit(self._run_check, c, executor) for c in checks
+            ])
+
+    def _run_check(self, check, executor):
+        ctx = DoctorContext(self, check.__name__, executor)
+        try:
+            self.log.debug('Running check %s', ctx.check)
+            check(ctx)
+            ctx.wait()
+        except _Fatal:
+            pass
+        except:  # pylint: disable=bare-except
+            self.failures.add(ctx.check)
+            self.log.exception('%s failed with an unexpected exception',
+                               check.__name__)
+
+        self.log.debug('Completed check %s', ctx.check)
+
+
 class DoctorContext:
-    """Base class for other checks."""
-    def __init__(self, *, strict=False, log=None):
-        self.name = self.__class__.__name__
-        self._strict = strict
-        self._log = log or logging.getLogger(__name__)
-        self.failures = set()
-        self.curr_checker = None
+    """The context object provided to each check function."""
+    def __init__(self, doctor: Doctor, check: str, executor: futures.Executor):
+        self._doctor = doctor
+        self.check = check
+        self._executor = executor
+        self._futures: List[futures.Future] = []
+
+    def submit(self, function, *args, **kwargs):
+        """Starts running the provided function in parallel."""
+        self._futures.append(
+            self._executor.submit(self._run_job, function, *args, **kwargs))
+
+    def wait(self):
+        """Waits for all parallel tasks started with submit() to complete."""
+        futures.wait(self._futures)
+        self._futures.clear()
+
+    def _run_job(self, function, *args, **kwargs):
+        try:
+            function(*args, **kwargs)
+        except _Fatal:
+            pass
+        except:  # pylint: disable=bare-except
+            self._doctor.failures.add(self.check)
+            self._doctor.log.exception(
+                '%s failed with an unexpected exception', self.check)
 
     def fatal(self, fmt, *args, **kwargs):
-        """Same as error() but terminates the checkearly."""
+        """Same as error() but terminates the check early."""
         self.error(fmt, *args, **kwargs)
         raise _Fatal()
 
     def error(self, fmt, *args, **kwargs):
-        self._log.error(fmt, *args, **kwargs)
-        self.failures.add(self.curr_checker)
+        self._doctor.log.error(fmt, *args, **kwargs)
+        self._doctor.failures.add(self.check)
 
     def warning(self, fmt, *args, **kwargs):
-        if self._strict:
+        if self._doctor.strict:
             self.error(fmt, *args, **kwargs)
         else:
-            self._log.warning(fmt, *args, **kwargs)
+            self._doctor.log.warning(fmt, *args, **kwargs)
 
     def info(self, fmt, *args, **kwargs):
-        self._log.info(fmt, *args, **kwargs)
+        self._doctor.log.info(fmt, *args, **kwargs)
 
     def debug(self, fmt, *args, **kwargs):
-        self._log.debug(fmt, *args, **kwargs)
+        self._doctor.log.debug(fmt, *args, **kwargs)
 
 
 def register_into(dest):
@@ -200,10 +248,7 @@
     json_path = root.joinpath('pw_env_setup', 'py', 'pw_env_setup',
                               'cipd_setup', 'pigweed.json')
 
-    with json_path.open() as ins:
-        packages = json.load(ins)
-
-    for package in packages:
+    def check_cipd(package):
         ctx.debug('checking version of %s', package['path'])
         name = [
             part for part in package['path'].split('/') if '{' not in part
@@ -211,7 +256,7 @@
         path = versions_path.joinpath(f'{name}.cipd_version')
         if not path.is_file():
             ctx.debug('no version file')
-            continue
+            return
 
         with path.open() as ins:
             installed = json.load(ins)
@@ -234,33 +279,26 @@
                     'CIPD package %s is out of date, please rerun bootstrap',
                     installed['package_name'])
 
+    for package in json.loads(json_path.read_text()):
+        ctx.submit(check_cipd, package)
 
-def doctor(strict=False, checks=None):
+
+def run_doctor(strict=False, checks=None):
     """Run all the Check subclasses defined in this file."""
 
-    ctx = DoctorContext(strict=strict)
-
     if checks is None:
         checks = tuple(CHECKS)
 
-    ctx.debug('Doctor running %d checks...', len(checks))
-    for check in checks:
-        try:
-            ctx.debug('Running %s...', check.__name__)
-            ctx.curr_checker = check.__name__
-            check(ctx)
+    doctor = Doctor(strict=strict)
+    doctor.log.debug('Doctor running %d checks...', len(checks))
 
-        except _Fatal:
-            pass
+    doctor.run(checks)
 
-        finally:
-            ctx.curr_checker = None
-
-    if ctx.failures:
-        ctx.info('Failed checks: %s', ', '.join(ctx.failures))
+    if doctor.failures:
+        doctor.log.info('Failed checks: %s', ', '.join(doctor.failures))
     else:
-        ctx.info('Environment passes all checks!')
-    return len(ctx.failures)
+        doctor.log.info('Environment passes all checks!')
+    return len(doctor.failures)
 
 
 def main() -> int:
@@ -272,7 +310,7 @@
         help='Run additional checks.',
     )
 
-    return doctor(**vars(parser.parse_args()))
+    return run_doctor(**vars(parser.parse_args()))
 
 
 if __name__ == '__main__':