blob: 625a2c6b2d68997183ce65343509eeb11ac6f45f [file] [log] [blame]
Wyatt Hepler13225142019-11-26 14:14:49 -08001#!/usr/bin/env python3
2
3# Copyright 2019 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9# https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Checks that C and C++ source files match clang-format's formatting."""
17
18import argparse
19import difflib
20import os
21import subprocess
22import sys
23from typing import Container, Iterable, List, Optional
24
# Extensions that mark a file as C or C++ source when searching directories.
SOURCE_EXTENSIONS = frozenset({'.c', '.cc', '.cpp', '.h', '.hh', '.hpp'})
# Formatter executable used when the caller does not name one.
DEFAULT_FORMATTER = 'clang-format'
27
28
def _make_color(*codes: int):
    """Builds a one-argument function that wraps text in ANSI SGR codes.

    On platforms where os.name is not 'posix', no escape sequences are added
    and the returned function is simply ``str``.
    """
    if os.name != 'posix':
        return str
    prefix = ''.join(map('\033[{}m'.format, codes))
    template = prefix + '{}' + '\033[0m'
    # Bind .format of the template (not of the text), so braces inside the
    # colored text are never interpreted as format fields.
    return template.format


# Formatters used for added (green) and removed (red) diff lines.
color_green = _make_color(32)
color_red = _make_color(31)
36
37
def _find_extensions(directory, extensions) -> Iterable[str]:
    """Yields paths of files below ``directory`` whose extension is listed.

    The directory tree is walked recursively with os.walk; a file matches
    when os.path.splitext(name)[1] is contained in ``extensions``.
    """
    for dirpath, _subdirs, filenames in os.walk(directory):
        matching = (name for name in filenames
                    if os.path.splitext(name)[1] in extensions)
        yield from (os.path.join(dirpath, name) for name in matching)
43
44
def list_files(paths: Iterable[str], extensions: Container[str]) -> List[str]:
    """Lists files with C or C++ extensions.

    Args:
      paths: files or directories. A path that names a file is included as
          is, whatever its extension. Any other existing path is searched
          recursively, and only files whose extension is in ``extensions``
          are included.
      extensions: file extensions (with the leading dot) to collect from
          searched directories.

    Returns:
      The sorted, de-duplicated list of file paths.

    Raises:
      FileNotFoundError: if a path does not exist. Raising instead of
          skipping keeps a mistyped path from making a format check pass
          without checking anything.
    """
    files = set()

    for path in paths:
        # os.walk silently yields nothing for a missing directory, so check
        # for existence explicitly first.
        if not os.path.exists(path):
            raise FileNotFoundError(f'No such file or directory: {path}')
        if os.path.isfile(path):
            files.add(path)
        else:
            files.update(_find_extensions(path, extensions))

    return sorted(files)
55 return sorted(files)
56
57
def clang_format(*args: str, formatter='clang-format') -> bytes:
    """Runs ``formatter`` with ``args`` and returns everything it wrote to stdout.

    Raises:
      subprocess.CalledProcessError: if the formatter exits with a non-zero
          status.
    """
    command = [formatter]
    command.extend(args)
    completed = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    return completed.stdout
63
64
def _colorize_diff_line(line: str) -> str:
    """Colors one unified-diff line: removals red, additions green.

    The ``--- `` and ``+++ `` file headers, and every line that starts with
    neither ``-`` nor ``+``, are returned unchanged.
    """
    if line.startswith(('--- ', '+++ ')):
        return line
    if line.startswith('-'):
        return color_red(line)
    if line.startswith('+'):
        return color_green(line)
    return line
71
72
73def colorize_diff(lines: Iterable[str]) -> str:
74 """Takes a diff str or list of str lines and returns a colorized version."""
75 if isinstance(lines, str):
76 lines = lines.splitlines(True)
77
78 return ''.join(_colorize_diff_line(line) for line in lines)
79
80
def _terminate_diff_line(line: str) -> str:
    """Gives a diff line a trailing newline, marking it if it had none."""
    # splitlines() returns [line] only when the line has no line terminator,
    # which difflib produces for a file's final line without a newline.
    if line.splitlines() == [line]:
        return line + '\n\\ No newline at end of file\n'
    return line


def clang_format_diff(path, formatter=DEFAULT_FORMATTER) -> Optional[str]:
    """Returns a diff comparing clang-format's output to the path's contents.

    Args:
      path: the file to check.
      formatter: the clang-format executable to run on ``path``.

    Returns:
      A unified diff (colorized by colorize_diff) if the formatter's output
      differs from the file's bytes, otherwise None.
    """
    with open(path, 'rb') as fd:
        current = fd.read()

    formatted = clang_format(path, formatter=formatter)

    if formatted == current:
        return None

    diff = difflib.unified_diff(
        current.decode(errors='replace').splitlines(True),
        formatted.decode(errors='replace').splitlines(True),
        f'{path} (original)', f'{path} (reformatted)')

    # Without this, a last line lacking a newline is glued to the following
    # diff line when the lines are joined (e.g. "-int x;+int x;").
    return colorize_diff(''.join(_terminate_diff_line(line) for line in diff))
96 return None
97
98
def check_format(files, formatter=DEFAULT_FORMATTER) -> List[str]:
    """Diffs files against clang-format; returns paths that did not match.

    The diff for each mismatching file is printed to stdout. If any file
    mismatches, a count and the list of paths are printed to stderr.

    Args:
      files: paths of the files to check.
      formatter: the clang-format executable to run.

    Returns:
      The paths whose contents differ from the formatter's output, in the
      order they were checked.
    """
    errors = []

    for path in files:
        difference = clang_format_diff(path, formatter)
        if difference:
            errors.append(path)
            print(difference)

    if errors:
        print(f'--> Files with formatting errors: {len(errors)}',
              file=sys.stderr)
        # Indent every path by the same two spaces.
        print('\n'.join(f'  {path}' for path in errors), file=sys.stderr)

    return errors
114 return errors
115
116
def _main(paths, fix, formatter) -> int:
    """Checks or fixes the formatting of the C/C++ files found in ``paths``.

    Returns:
      The process exit status: 0 after fixing or when every file is already
      formatted, 1 when at least one file needs reformatting.
    """
    files = list_files(paths, SOURCE_EXTENSIONS)

    if not fix:
        return 1 if check_format(files, formatter) else 0

    # Reformat each file in place via the formatter's -i flag.
    for path in files:
        clang_format(path, '-i', formatter=formatter)
    return 0
128
129
def _parse_args() -> argparse.Namespace:
    """Builds the command-line parser and parses the process arguments."""
    parser = argparse.ArgumentParser(description=__doc__)

    paths_help = ('Files or directories to check. '
                  'Within a directory, only C or C++ files are checked.')
    parser.add_argument('paths',
                        nargs='*',
                        default=[os.getcwd()],
                        help=paths_help)
    parser.add_argument('--fix',
                        action='store_true',
                        help='Apply clang-format fixes in place.')
    parser.add_argument('--formatter',
                        default=DEFAULT_FORMATTER,
                        help='The clang-format binary to use.')

    args = parser.parse_args()
    return args
148
149
if __name__ == '__main__':
    # Parsed option names match _main's parameter names one-to-one.
    _ARGS = _parse_args()
    sys.exit(_main(**vars(_ARGS)))