Wyatt Hepler | 1322514 | 2019-11-26 14:14:49 -0800 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | # Copyright 2019 The Pigweed Authors |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| 6 | # use this file except in compliance with the License. You may obtain a copy of |
| 7 | # the License at |
| 8 | # |
| 9 | # https://www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 13 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 14 | # License for the specific language governing permissions and limitations under |
| 15 | # the License. |
| 16 | """Checks that C and C++ source files match clang-format's formatting.""" |
| 17 | |
| 18 | import argparse |
| 19 | import difflib |
| 20 | import os |
| 21 | import subprocess |
| 22 | import sys |
| 23 | from typing import Container, Iterable, List, Optional |
| 24 | |
# Extensions (with leading dot) of the C and C++ files picked up when a
# directory is scanned.
SOURCE_EXTENSIONS = frozenset({'.c', '.cc', '.cpp', '.h', '.hh', '.hpp'})
# The clang-format binary run when no --formatter is given.
DEFAULT_FORMATTER = 'clang-format'
| 27 | |
| 28 | |
| 29 | def _make_color(*codes: int): |
| 30 | start = ''.join(f'\033[{code}m' for code in codes) |
| 31 | return f'{start}{{}}\033[0m'.format if os.name == 'posix' else str |
| 32 | |
| 33 | |
# Colorizers for added (green) and removed (red) diff lines.
color_green, color_red = _make_color(32), _make_color(31)
| 36 | |
| 37 | |
| 38 | def _find_extensions(directory, extensions) -> Iterable[str]: |
| 39 | for root, _, files in os.walk(directory): |
| 40 | for file in files: |
| 41 | if os.path.splitext(file)[1] in extensions: |
| 42 | yield os.path.join(root, file) |
| 43 | |
| 44 | |
def list_files(paths: Iterable[str], extensions: Container[str]) -> List[str]:
    """Lists files with C or C++ extensions.

    A path that is a regular file is included as given, whatever its
    extension. Any other path is walked as a directory, keeping only files
    whose extension is in extensions.

    NOTE(review): os.walk ignores errors by default, so a path that does not
    exist silently contributes no files.

    Returns:
      The de-duplicated file paths, sorted.
    """
    found = set()
    for path in paths:
        if not os.path.isfile(path):
            found.update(_find_extensions(path, extensions))
            continue
        found.add(path)
    return sorted(found)
| 56 | |
| 57 | |
def clang_format(*args: str, formatter=DEFAULT_FORMATTER) -> bytes:
    """Returns the output of clang-format with the provided arguments.

    Args:
      *args: Command-line arguments for the formatter, e.g. a file path,
        optionally followed by '-i' to reformat the file in place.
      formatter: The clang-format binary to run. Defaults to
        DEFAULT_FORMATTER, like the other functions in this module, so the
        default is defined in one place.

    Returns:
      Everything the formatter wrote to stdout. stderr is not captured and
      goes straight to the terminal.

    Raises:
      subprocess.CalledProcessError: The formatter exited with non-zero status.
      OSError: The formatter could not be executed (e.g. it is not on PATH).
    """
    return subprocess.run([formatter, *args],
                          stdout=subprocess.PIPE,
                          check=True).stdout
| 63 | |
| 64 | |
def _colorize_diff_line(line: str) -> str:
    """Colors one unified-diff line: removals red, additions green.

    The '--- ' and '+++ ' file-header lines, and every other line, are
    returned unchanged.
    """
    if line.startswith('--- ') or line.startswith('+++ '):
        return line
    if line.startswith('-'):
        return color_red(line)
    if line.startswith('+'):
        return color_green(line)
    return line
| 71 | |
| 72 | |
def colorize_diff(lines: Iterable[str]) -> str:
    """Takes a diff str or list of str lines and returns a colorized version.

    A str argument is first split into lines, keeping their line endings.
    Each line is colored by _colorize_diff_line and the results are
    concatenated.
    """
    diff_lines = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(map(_colorize_diff_line, diff_lines))
| 79 | |
| 80 | |
def _mark_missing_newlines(diff_lines: Iterable[str]) -> Iterable[str]:
    """Terminates unterminated diff lines, adding GNU diff's marker line.

    difflib.unified_diff passes input lines through unchanged. When a file's
    last line has no line break, that diff line would otherwise run straight
    into the next one (e.g. '-foo+foo2').
    """
    for line in diff_lines:
        # str.splitlines leaves a line unchanged only if it has no line break.
        # Input comes from splitlines(True), so only a file's final line can
        # lack one.
        if line.splitlines() == [line]:
            yield line + '\n'
            yield '\\ No newline at end of file\n'
        else:
            yield line


def clang_format_diff(path, formatter=DEFAULT_FORMATTER) -> Optional[str]:
    """Returns a diff comparing clang-format's output to the path's contents.

    Args:
      path: The file to check.
      formatter: The clang-format binary to run on path.

    Returns:
      A colorized unified diff from the file's current contents to
      clang-format's output, or None if the two are byte-for-byte identical.

    Raises:
      OSError: path could not be read.
      subprocess.CalledProcessError: The formatter exited with non-zero status.
    """
    with open(path, 'rb') as fd:
        current = fd.read()

    formatted = clang_format(path, formatter=formatter)

    if formatted == current:
        return None

    # Decode leniently so files that are not valid UTF-8 still get a diff.
    diff = difflib.unified_diff(
        current.decode(errors='replace').splitlines(True),
        formatted.decode(errors='replace').splitlines(True),
        f'{path} (original)', f'{path} (reformatted)')

    return colorize_diff(_mark_missing_newlines(diff))
| 97 | |
| 98 | |
def check_format(files, formatter=DEFAULT_FORMATTER) -> List[str]:
    """Diffs files against clang-format; returns paths that did not match.

    Prints each mismatching file's diff to stdout. If any files mismatched,
    also prints a count and list of them to stderr.

    Args:
      files: Paths of the files to check.
      formatter: The clang-format binary to run.

    Returns:
      The paths whose contents differ from clang-format's output, in the
      order they were checked.
    """
    errors = []

    for path in files:
        difference = clang_format_diff(path, formatter)
        # None is the only "matches" result. An empty diff string still means
        # the bytes differ (e.g. only in bytes lost by lenient decoding), so it
        # must not be treated as a pass the way a truthiness test would.
        if difference is not None:
            errors.append(path)
            print(difference)

    if errors:
        print(f'--> Files with formatting errors: {len(errors)}',
              file=sys.stderr)
        print(' ', '\n '.join(errors), file=sys.stderr)

    return errors
| 115 | |
| 116 | |
def _main(paths, fix, formatter) -> int:
    """Checks or fixes the formatting of C/C++ files under paths.

    With fix set, reformats each file in place and returns 0. Otherwise,
    runs check_format and returns 1 if any file mismatched, else 0.
    """
    files = list_files(paths, SOURCE_EXTENSIONS)

    if not fix:
        return 1 if check_format(files, formatter) else 0

    for source_file in files:
        clang_format(source_file, '-i', formatter=formatter)
    return 0
| 128 | |
| 129 | |
def _parse_args() -> argparse.Namespace:
    """Builds the command-line parser and returns the parsed arguments.

    The resulting namespace has paths (defaulting to the current working
    directory), fix, and formatter.
    """
    paths_help = ('Files or directories to check. '
                  'Within a directory, only C or C++ files are checked.')

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('paths',
                        nargs='*',
                        default=[os.getcwd()],
                        help=paths_help)
    parser.add_argument('--fix',
                        help='Apply clang-format fixes in place.',
                        action='store_true')
    parser.add_argument('--formatter',
                        help='The clang-format binary to use.',
                        default=DEFAULT_FORMATTER)
    return parser.parse_args()
| 148 | |
| 149 | |
if __name__ == '__main__':
    # Run as a script: exit with _main's status code.
    _ARGS = _parse_args()
    sys.exit(
        _main(paths=_ARGS.paths, fix=_ARGS.fix, formatter=_ARGS.formatter))