blob: 625a2c6b2d68997183ce65343509eeb11ac6f45f [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2019 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Checks that C and C++ source files match clang-format's formatting."""
import argparse
import difflib
import os
import subprocess
import sys
from typing import Container, Iterable, List, Optional
# File extensions treated as C or C++ sources when directories are scanned.
SOURCE_EXTENSIONS = frozenset(('.c', '.cc', '.cpp', '.h', '.hh', '.hpp'))

# Name of the clang-format binary used when no other formatter is requested.
DEFAULT_FORMATTER = 'clang-format'
def _make_color(*codes: int):
    """Builds a callable that wraps text in the given ANSI escape codes.

    On POSIX systems the returned callable prefixes its argument with one
    escape sequence per code and appends a reset sequence. On any other
    platform, ``str`` is returned and the text is left unmodified.
    """
    prefix = ''.join('\033[{}m'.format(code) for code in codes)
    if os.name != 'posix':
        return str
    template = prefix + '{}\033[0m'
    return template.format


# Colorizers for added (green) and removed (red) diff lines.
color_green = _make_color(32)
color_red = _make_color(31)
def _find_extensions(directory, extensions) -> Iterable[str]:
    """Yields paths under directory whose file extension is in extensions.

    The directory tree is walked recursively with os.walk; each yielded path
    is the walked directory joined with the file name.
    """
    for dirpath, _dirnames, filenames in os.walk(directory):
        yield from (os.path.join(dirpath, name) for name in filenames
                    if os.path.splitext(name)[1] in extensions)
def list_files(paths: Iterable[str], extensions: Container[str]) -> List[str]:
    """Collects the files to process from a mix of file and directory paths.

    Paths that are existing files are included as given. Every other path is
    searched recursively for files whose extension is in extensions.

    Returns:
      The unique file paths, sorted.
    """
    found = set()
    for entry in paths:
        if not os.path.isfile(entry):
            found.update(_find_extensions(entry, extensions))
            continue
        found.add(entry)
    return sorted(found)
def clang_format(*args: str, formatter='clang-format') -> bytes:
    """Runs the formatter with the given arguments and returns its stdout.

    Args:
      *args: command line arguments passed to the formatter
      formatter: the executable to invoke

    Raises:
      subprocess.CalledProcessError: the formatter exited with a non-zero
          status
    """
    command = [formatter]
    command.extend(args)
    completed = subprocess.run(command, stdout=subprocess.PIPE, check=True)
    return completed.stdout
def _colorize_diff_line(line: str) -> str:
    """Colors one diff line: removals red, additions green, others unchanged.

    Lines starting with '--- ' or '+++ ' are returned without color.
    """
    if line.startswith(('--- ', '+++ ')):
        return line
    if line.startswith('-'):
        return color_red(line)
    if line.startswith('+'):
        return color_green(line)
    return line
def colorize_diff(lines: Iterable[str]) -> str:
    """Returns a colorized copy of a diff as a single string.

    Args:
      lines: the diff, either as one str (split into lines with their line
          endings kept) or as an iterable of str lines
    """
    pieces = lines.splitlines(True) if isinstance(lines, str) else lines
    return ''.join(_colorize_diff_line(piece) for piece in pieces)
def clang_format_diff(path, formatter=DEFAULT_FORMATTER) -> Optional[str]:
    """Compares a file's contents with the formatter's output for that file.

    Args:
      path: the file to check
      formatter: the formatter executable to run on path

    Returns:
      A colorized unified diff from the current contents to the formatted
      output, or None when the two are byte-for-byte identical.
    """
    with open(path, 'rb') as source:
        original = source.read()

    reformatted = clang_format(path, formatter=formatter)
    if reformatted == original:
        return None

    # Decode leniently so that invalid bytes cannot abort the diff.
    before = original.decode(errors='replace').splitlines(True)
    after = reformatted.decode(errors='replace').splitlines(True)
    return colorize_diff(
        difflib.unified_diff(before, after, f'{path} (original)',
                             f'{path} (reformatted)'))
def check_format(files, formatter=DEFAULT_FORMATTER) -> List[str]:
    """Diffs files against clang-format; returns paths that did not match.

    The diff for each mismatched file is printed to stdout. If any file did
    not match, a summary with the number of such files and their paths, one
    per line, is printed to stderr.

    Args:
      files: paths of the files to check
      formatter: the formatter executable to use

    Returns:
      The paths whose contents differ from the formatter's output, in the
      order they were checked.
    """
    errors = []

    for path in files:
        difference = clang_format_diff(path, formatter)
        if difference:
            errors.append(path)
            print(difference)

    if errors:
        print(f'--> Files with formatting errors: {len(errors)}',
              file=sys.stderr)
        # Prefix each path itself so all entries share the same indent; a
        # separate indent argument to print() would give only the first path
        # an extra separator space.
        print('\n'.join(f'  {path}' for path in errors), file=sys.stderr)

    return errors
def _main(paths, fix, formatter) -> int:
    """Checks formatting, or rewrites the files in place when fix is set.

    Returns:
      The process exit code: 0 after fixing or when every file matches, and
      1 when at least one file has formatting errors.
    """
    files = list_files(paths, SOURCE_EXTENSIONS)

    if not fix:
        return int(bool(check_format(files, formatter)))

    for path in files:
        clang_format(path, '-i', formatter=formatter)
    return 0
def _parse_args() -> argparse.Namespace:
    """Builds the command line parser and returns the parsed arguments."""
    parser = argparse.ArgumentParser(description=__doc__)

    paths_help = ('Files or directories to check. '
                  'Within a directory, only C or C++ files are checked.')
    # The current working directory is used when no paths are given.
    parser.add_argument('paths',
                        nargs='*',
                        default=[os.getcwd()],
                        help=paths_help)
    parser.add_argument('--fix',
                        help='Apply clang-format fixes in place.',
                        action='store_true')
    parser.add_argument('--formatter',
                        help='The clang-format binary to use.',
                        default=DEFAULT_FORMATTER)

    return parser.parse_args()
if __name__ == '__main__':
    # The parsed argument names match _main's parameters, so the namespace
    # can be unpacked directly into the call.
    _exit_code = _main(**vars(_parse_args()))
    sys.exit(_exit_code)