Created
October 26, 2023 16:01
-
-
Save mrcljx/57d127e97ed33063f6a5b51a55d34036 to your computer and use it in GitHub Desktop.
EOF Fixer (exactly one newline)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Check if files end with a single newline. | |
""" | |
import argparse | |
import io | |
import logging | |
import pathlib | |
import sys | |
from collections.abc import Iterator | |
from typing import IO | |
import pytest | |
from identify import identify | |
log = logging.getLogger(__name__)

# Integer values of the line-terminator bytes. Iterating over ``bytes`` yields
# integers, so the tail check compares against these rather than characters.
ORD_NL = ord("\n")
ORD_CR = ord("\r")


def _check_tail(data: Iterator[int]) -> bool:
    """
    Check if the tail of a file is formatted correctly,
    i.e. ends with a single newline.

    ``data`` yields the trailing bytes of the file as integers in reverse
    order (last byte of the file first).

    Returns True when the input is empty, or when it ends with one line
    terminator (LF, optionally preceded by CR) that is not itself preceded
    by an ASCII whitespace byte. Returns False otherwise, which covers a
    missing final newline, several trailing newlines, and trailing
    whitespace on the last line.
    """
    last = next(data, None)
    if last is None:
        # Empty file: there is no line to terminate.
        return True
    if last != ORD_NL:
        return False

    before = next(data, None)
    if before == ORD_CR:
        # CRLF terminator: step over the CR to reach the preceding byte.
        before = next(data, None)
    if before is None:
        # The file consists of the line terminator only.
        return True

    # The byte must be tested as a byte: ``str(before)`` would be the decimal
    # text of the integer (e.g. "10"), which is never whitespace. The
    # ``bytes.isspace`` test is ASCII-only, so bytes >= 0x80 belonging to a
    # multi-byte UTF-8 character are not mistaken for whitespace.
    return not bytes((before,)).isspace()


def _read_tail(f: IO[bytes], n: int) -> bytes:
    """Read the tail of a file: up to ``n`` bytes counted from its end."""
    try:
        f.seek(-n, io.SEEK_END)
    except OSError:
        # The seek target lies before the start of the file (e.g. the file is
        # shorter than ``n`` bytes); read from the beginning instead.
        f.seek(0, io.SEEK_SET)
    return f.read(n)


def _check_fast(f: IO[bytes]) -> bool:
    """Check if a file is well-formed by looking only at its last bytes."""
    # Three bytes are enough: an optional CR, the LF, and the byte before them.
    return _check_tail(reversed(_read_tail(f, 3)))


@pytest.mark.parametrize(
    ("example", "expected"),
    [
        (b"\r", False),
        (b"x", False),
        (b"", True),
        (b"\n", True),
        (b"x\n", True),
        (b"\r\n", True),
        (b"a\r\n", True),
        # More than one trailing newline is malformed.
        (b"\n\n", False),
        (b"x\n\n", False),
        (b"x\r\n\r\n", False),
        # Trailing whitespace before the final newline is malformed.
        (b"x \n", False),
        (b"x\t\n", False),
        # A UTF-8 continuation byte (0xA0) must not count as whitespace.
        (b"\xc3\xa0\n", True),
    ],
)
def test_check_fast(example: bytes, expected: bool) -> None:
    """Verify the tail check against in-memory byte streams."""
    assert _check_fast(io.BytesIO(example)) == expected
def check_file(path: pathlib.Path) -> bool:
    """Return the result of the fast end-of-file check for the file at ``path``.

    The file is opened in binary mode and handed to ``_check_fast``.
    """
    with open(path, "rb") as stream:
        well_formed = _check_fast(stream)
    return well_formed
def fix_file(path: pathlib.Path) -> None:
    """Rewrite the file at ``path`` in place so it ends with exactly one newline.

    Trailing lines that are empty or whitespace-only are removed, trailing
    ASCII whitespace on the last remaining line is stripped, and a single
    line terminator is appended. The terminator is CRLF when the first line
    of the file ends with CRLF, otherwise LF. An empty file is left
    untouched; a file containing only whitespace becomes a single terminator.

    The file is processed as bytes so that existing line endings are kept
    as they are and no text decoding is involved.
    """
    with open(path, "rb+") as f:
        lines = f.readlines()
        if not lines:
            return

        # Follow the file's own line-ending convention, judged by line one.
        ending = b"\r\n" if lines[0].endswith(b"\r\n") else b"\n"

        # Drop blank / whitespace-only lines from the end of the file.
        while lines and not lines[-1].strip():
            lines.pop()

        if lines:
            lines[-1] = lines[-1].rstrip() + ending
        else:
            lines.append(ending)

        # Truncate only once the new content is fully computed.
        f.seek(0, io.SEEK_SET)
        f.truncate()
        f.writelines(lines)
@pytest.mark.parametrize(
    ("example", "expected"),
    [
        ("\n", "\n"),
        ("a", "a\n"),
        ("a\n\n", "a\n"),
        ("a\na\n", "a\na\n"),
    ],
)
def test_fix_file(tmp_path: pathlib.Path, example: str, expected: str) -> None:
    """Write ``example`` to a temporary file, fix it, and compare the result."""
    target = tmp_path / "file.txt"
    target.write_text(example)

    fix_file(target)

    result = target.read_text()
    assert result == expected
def main() -> None:
    """Command-line entry point: check (or fix) the end of each given file.

    Paths whose ``identify`` tags do not include "text" are skipped. Files
    that fail the check are rewritten when ``--fix`` is given; otherwise they
    are reported, and the process exits with status 1 after all paths have
    been handled.
    """
    cli = argparse.ArgumentParser(description="File end of file")
    cli.add_argument("--fix", action="store_true")
    cli.add_argument(
        "path",
        type=pathlib.Path,
        nargs="+",
    )
    options = cli.parse_args()

    # INFO level: the SKIP/OK debug messages below are not shown.
    logging.basicConfig(level=logging.INFO)

    malformed: list[pathlib.Path] = []
    for candidate in options.path:
        if "text" not in identify.tags_from_path(candidate):
            log.debug("SKIP: %s", candidate)
            continue
        if check_file(candidate):
            log.debug("OK: %s", candidate)
            continue
        if options.fix:
            log.info("FIX: %s", candidate)
            fix_file(candidate)
            continue
        log.error("BAD: %s", candidate)
        malformed.append(candidate)

    if not malformed:
        return
    log.error("Found %d malformed files", len(malformed))
    sys.exit(1)
# Run the command-line interface only when executed as a script, so that
# importing this module (e.g. by pytest) has no side effects.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment