Last active
November 4, 2024 05:18
-
-
Save milkey-mouse/ad54429ab562819530b4d336773453c2 to your computer and use it in GitHub Desktop.
Sort imports in Python files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# sort-imports.py: Sort imports in Python files. | |
# | |
# Usage: ./sort-imports.py file1.py file2.py ... | |
# | |
# Sorts imports within "stanzas". A stanza is a block of consecutive import | |
# statements with the same indentation. Each stanza is sorted to contain first | |
# all `from` imports, sorted by module, then by imported names; then `import` | |
# imports, sorted by module. | |
from contextlib import suppress
from dataclasses import dataclass
from functools import total_ordering
from io import StringIO
from typing import Iterator, Self
import os
import sys
@dataclass(order=True, slots=True) | |
class ImportInner: | |
module: str | |
alias: str | None = None | |
def __str__(self): | |
return f"{self.module} as {self.alias}" if self.alias else self.module | |
@classmethod | |
def try_parse(cls, line: str) -> Self | None: | |
with suppress(ValueError, AssertionError): | |
module, _as, alias = line.split() | |
assert _as == "as" | |
return cls(module, alias) | |
with suppress(ValueError, AssertionError): | |
(module,) = line.split() | |
return cls(module) | |
return None | |
@dataclass(order=True, slots=True)
class Import:
    """A single-name ``import NAME[ as ALIAS]`` statement.

    Instances order by their ``inner`` name, which is how ``import`` lines
    within a stanza are sorted.
    """

    inner: ImportInner  # the imported module and its optional alias

    def __str__(self):
        return f"import {self.inner}"

    @classmethod
    def try_parse(cls, line: str) -> Self | None:
        """Parse *line* as ``import NAME[ as ALIAS]``; return None otherwise.

        Surrounding whitespace is ignored. Lines that import several names at
        once (``import a, b``) are not recognised, because
        ``ImportInner.try_parse`` rejects the comma-joined text.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would make any two-word line (``return x``)
        # parse as ``import x``.
        parts = line.split(maxsplit=1)
        if len(parts) != 2 or parts[0] != "import":
            return None
        inner = ImportInner.try_parse(parts[1])
        if inner is None:
            return None
        return cls(inner)
@dataclass(order=True, slots=True)
class From:
    """A ``from MODULE import NAME[ as ALIAS], ...`` statement.

    Instances order by module, then by their list of imported names (kept
    sorted by ``try_parse``). ``__str__`` always renders the single-line,
    unparenthesized form.
    """

    module: str  # text after ``from``, e.g. "os.path" or ".pkg"
    imports: list[ImportInner]  # imported names, sorted by try_parse

    def __str__(self):
        names = ", ".join(map(str, self.imports))
        return f"from {self.module} import {names}"

    @staticmethod
    def _parse_names(text: str) -> list[ImportInner] | None:
        """Parse comma-separated ``NAME[ as ALIAS]`` items.

        Returns None if any item (including an empty one) is malformed.
        """
        names = []
        for part in text.split(","):
            name = ImportInner.try_parse(part.strip())
            if name is None:
                return None
            names.append(name)
        return names

    @classmethod
    def _parse_block(cls, module: str, lines: Iterator[str]) -> list[ImportInner]:
        """Consume and parse the body of a multi-line ``from M import (`` block.

        Reads from *lines* up to and including the line containing the
        closing parenthesis. Each line may hold several names and trailing
        commas, and the closing ``)`` may follow the last name.

        Raises:
            ValueError: if a line cannot be parsed (e.g. it has a comment),
                the block is never closed, or it is empty. The lines are
                already consumed at that point, so returning None would
                silently drop them from the rewritten file.
        """
        imports: list[ImportInner] = []
        for raw in lines:
            text = raw.strip()
            closed = text.endswith(")")
            if closed:
                text = text[:-1].rstrip()
            text = text.removesuffix(",")
            if text:  # blank lines (and a bare ")") carry no names
                names = cls._parse_names(text)
                if names is None:
                    raise ValueError(
                        f"cannot parse {raw.strip()!r} inside "
                        f"'from {module} import (...)'"
                    )
                imports.extend(names)
            if closed:
                break
        else:
            raise ValueError(f"unterminated 'from {module} import (' block")
        if not imports:
            raise ValueError(f"empty 'from {module} import ()' block")
        return imports

    @classmethod
    def try_parse(cls, line: str, lines: Iterator[str]) -> Self | None:
        """Parse a ``from`` import starting at *line*; return None if it isn't one.

        Accepted forms::

            from M import a, b as c
            from M import (a, b,)
            from M import (
                a, b,
                c,
            )

        Only for the multi-line form (*line* ending in a lone ``(``) are
        continuation lines consumed from *lines*; otherwise *lines* is not
        touched.

        Raises:
            ValueError: if a multi-line block cannot be parsed after lines
                have been consumed (see ``_parse_block``).
        """
        # Explicit checks instead of ``assert`` so parsing is unchanged under
        # ``python -O``.
        tokens = line.split(maxsplit=3)
        if len(tokens) != 4 or tokens[0] != "from" or tokens[2] != "import":
            return None
        module, rest = tokens[1], tokens[3].strip()

        if rest == "(":
            imports = cls._parse_block(module, lines)
        else:
            if rest.startswith("(") and rest.endswith(")"):
                # Single-line parenthesized form; a trailing comma is legal here.
                rest = rest[1:-1].strip().removesuffix(",")
            elif "(" in rest or ")" in rest:
                # e.g. "(a," opening a block we don't handle: leave it alone
                # rather than emitting names like "(a".
                return None
            parsed = cls._parse_names(rest)
            if parsed is None:
                return None
            imports = parsed
        imports.sort()
        return cls(module, imports)
def sort_imports(filename: str) -> None:
    """Sort the import stanzas of *filename* in place.

    A stanza is a run of consecutive import lines with the same indentation;
    any other line (including blank lines and comments) ends it. Within a
    stanza, ``from`` imports are written first, then ``import`` imports, each
    sorted, with ``from __future__`` imports ahead of all others.

    The whole file is parsed before anything is written, so an exception
    leaves it unmodified. It is also left untouched when sorting changes
    nothing.
    """
    with open(filename, "r+", encoding="utf-8") as infile:
        original = infile.read()
        # Iterating a StringIO yields "\n"-terminated lines just as iterating
        # the file did, and From.try_parse can still consume continuation
        # lines from the same iterator.
        lines = StringIO(original)
        out = StringIO()
        indent = ""
        from_imports = []
        imports = []

        def flush_imports():
            # Write the pending stanza at the current indent, then reset it.
            # ``from __future__`` must precede other statements, but "_" sorts
            # after uppercase letters (e.g. "PIL"), so force it to the front.
            from_imports.sort(key=lambda imp: (imp.module != "__future__", imp))
            imports.sort()
            for stmt in (*from_imports, *imports):
                print(indent, stmt, sep="", file=out)
            from_imports.clear()
            imports.clear()

        def update_indent(line: str):
            # A change of indentation starts a new stanza.
            nonlocal indent
            new_indent = line[: len(line) - len(line.lstrip())]
            if new_indent != indent:
                flush_imports()
                indent = new_indent

        for line in lines:
            if from_imp := From.try_parse(line, lines):
                update_indent(line)
                from_imports.append(from_imp)
            elif imp := Import.try_parse(line):
                update_indent(line)
                imports.append(imp)
            else:
                flush_imports()
                out.write(line)
        # Imports on the file's last lines are only emitted by this flush;
        # without it they would be deleted.
        flush_imports()

        result = out.getvalue()
        if result != original:
            infile.seek(0)
            infile.write(result)
            infile.truncate()
def main():
    """Sort imports in every file named on the command line.

    Exits with status 1 on a usage error, or if any file could not be
    processed. A failing file is reported on stderr and the remaining files
    are still attempted.
    """
    filenames = sys.argv[1:]
    if not filenames:
        print("Usage: ./sort-imports.py file1.py file2.py ...", file=sys.stderr)
        sys.exit(1)
    failed = False
    for filename in filenames:
        try:
            sort_imports(filename)
        except (OSError, ValueError) as err:
            # OSError: unreadable/unwritable file. ValueError: an import we
            # refuse to rewrite, or UnicodeDecodeError (a ValueError subclass).
            print(f"{filename}: {err}", file=sys.stderr)
            failed = True
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
License: CC0-1.0