Skip to content

Instantly share code, notes, and snippets.

@mrcljx
Created October 26, 2023 16:01
Show Gist options
  • Save mrcljx/57d127e97ed33063f6a5b51a55d34036 to your computer and use it in GitHub Desktop.
Save mrcljx/57d127e97ed33063f6a5b51a55d34036 to your computer and use it in GitHub Desktop.
EOF Fixer (exactly one newline)
#!/usr/bin/env python3
"""
Check if files end with a single newline.
"""
import argparse
import io
import logging
import pathlib
import sys
from collections.abc import Iterator
from typing import IO
import pytest
from identify import identify
# Module-level logger; main() reports the per-file SKIP/OK/FIX/BAD status through it.
log = logging.getLogger(__name__)
# Integer byte values of line feed and carriage return, compared against
# the ints produced when iterating over bytes read from a file.
ORD_NL = ord("\n")
ORD_CR = ord("\r")
def _check_tail(data: Iterator[int]) -> bool:
    """
    Check if the tail of a file is formatted correctly,
    i.e. ends with a single newline.

    ``data`` yields the trailing byte values of the file in reverse order
    (last byte first). An empty file is considered correct. Otherwise the
    last byte must be LF, optionally preceded by CR, and the byte before
    that line ending (if there is one) must not be ASCII whitespace.
    """
    last = next(data, None)
    if last is None:
        # Nothing to read: an empty file is accepted as-is.
        return True
    if last != ORD_NL:
        return False
    before = next(data, None)
    if before == ORD_CR:
        # CRLF line ending: look at the byte in front of the CR instead.
        before = next(data, None)
    if before is None:
        # The file consists of exactly one line ending.
        return True
    # Test the raw byte against ASCII whitespace. Converting the int with
    # str() would yield decimal digits (never whitespace), and chr() would
    # misclassify UTF-8 continuation bytes such as 0x85 or 0xA0.
    return not bytes((before,)).isspace()


def _read_tail(f: IO[bytes], n: int) -> bytes:
    """
    Read at most the last ``n`` bytes of a binary file.

    Seeks ``n`` bytes back from the end; if that seek raises ``OSError``
    (expected when the file holds fewer than ``n`` bytes), reads from the
    start of the file instead.
    """
    try:
        f.seek(-n, io.SEEK_END)
    except OSError:
        f.seek(0, io.SEEK_SET)
    return f.read(n)


def _check_fast(f: IO[bytes]) -> bool:
    """
    Return True if the file ends correctly, inspecting only its tail.

    Three bytes are enough for the check: the LF, an optional CR in front
    of it, and the byte preceding the line ending.
    """
    return _check_tail(reversed(_read_tail(f, 3)))


@pytest.mark.parametrize(
    ("example", "expected"),
    [
        (b"\r", False),
        (b"x", False),
        (b"", True),
        (b"\n", True),
        (b"x\n", True),
        (b"\r\n", True),
        (b"a\r\n", True),
        # More than one trailing line ending, or whitespace right before
        # the final line ending, is malformed.
        (b"\n\n", False),
        (b"x\n\n", False),
        (b"x\r\n\r\n", False),
        (b"x \n", False),
        # 0x85 is a UTF-8 continuation byte here, not whitespace.
        (b"\xc4\x85\n", True),
    ],
)
def test_check_fast(example: bytes, expected: bool) -> None:
    assert _check_fast(io.BytesIO(example)) == expected
def check_file(path: pathlib.Path) -> bool:
    """Open ``path`` in binary mode and return the result of the fast tail check."""
    with path.open("rb") as stream:
        verdict = _check_fast(stream)
    return verdict
def fix_file(path: pathlib.Path) -> None:
    """
    Rewrite ``path`` in place so that it ends with exactly one line ending.

    Trailing whitespace-only lines are dropped, and trailing whitespace on
    the last remaining line is replaced by a single line ending. CRLF is
    used when the first line of the file ends with CRLF, LF otherwise.
    An empty file is left untouched; a file containing only whitespace
    lines is reduced to a single line ending.
    """
    # newline="" turns off universal-newline translation in both directions,
    # so CRLF endings are visible to the detection below and every untouched
    # line is written back exactly as it was read.
    with open(path, "r+", newline="") as f:
        lines = f.readlines()
        if not lines:
            return
        ending = "\r\n" if lines[0].endswith("\r\n") else "\n"
        # Drop whitespace-only lines from the end of the file.
        while lines and not lines[-1].strip():
            lines.pop()
        if lines:
            lines[-1] = lines[-1].rstrip() + ending
        else:
            lines.append(ending)
        # Only discard the old content once the new content is ready.
        f.seek(0, io.SEEK_SET)
        f.truncate()
        f.writelines(lines)
@pytest.mark.parametrize(
    "example,expected",
    [
        pytest.param("\n", "\n"),
        pytest.param("a", "a\n"),
        pytest.param("a\n\n", "a\n"),
        pytest.param("a\na\n", "a\na\n"),
    ],
)
def test_fix_file(tmp_path: pathlib.Path, example: str, expected: str) -> None:
    """Write ``example`` to a temporary file, fix it, and compare the result."""
    target = tmp_path / "file.txt"
    target.write_text(example)
    fix_file(target)
    rewritten = target.read_text()
    assert rewritten == expected
def main() -> None:
    """
    Command-line entry point.

    Accepts one or more paths and an optional ``--fix`` flag. Paths that
    ``identify`` does not tag as text are skipped. Every remaining file is
    checked; a failing file is rewritten when ``--fix`` is given, otherwise
    it is recorded as an error. Exits with status 1 if any error was recorded.
    """
    cli = argparse.ArgumentParser(description="File end of file")
    cli.add_argument("--fix", action="store_true")
    cli.add_argument("path", type=pathlib.Path, nargs="+")
    options = cli.parse_args()
    logging.basicConfig(level=logging.INFO)

    malformed: list[pathlib.Path] = []
    for candidate in options.path:
        if "text" not in identify.tags_from_path(candidate):
            log.debug("SKIP: %s", candidate)
            continue
        if check_file(candidate):
            log.debug("OK: %s", candidate)
            continue
        if options.fix:
            log.info("FIX: %s", candidate)
            fix_file(candidate)
            continue
        log.error("BAD: %s", candidate)
        malformed.append(candidate)

    if not malformed:
        return
    log.error("Found %d malformed files", len(malformed))
    sys.exit(1)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment