Skip to content

Instantly share code, notes, and snippets.

@milkey-mouse
Last active November 4, 2024 05:18
Show Gist options
  • Save milkey-mouse/ad54429ab562819530b4d336773453c2 to your computer and use it in GitHub Desktop.
Save milkey-mouse/ad54429ab562819530b4d336773453c2 to your computer and use it in GitHub Desktop.
Sort imports in Python files
#!/usr/bin/env python3
# sort-imports.py: Sort imports in Python files.
#
# Usage: ./sort-imports.py file1.py file2.py ...
#
# Sorts imports within "stanzas". A stanza is a block of consecutive import
# statements with the same indentation. Each stanza is sorted to contain first
# all `from` imports, sorted by module, then by imported names; then `import`
# imports, sorted by module.
from contextlib import suppress
from dataclasses import dataclass
from io import StringIO
from typing import Iterator, Self
import os
import sys
@dataclass(order=True, slots=True)
class ImportInner:
module: str
alias: str | None = None
def __str__(self):
return f"{self.module} as {self.alias}" if self.alias else self.module
@classmethod
def try_parse(cls, line: str) -> Self | None:
with suppress(ValueError, AssertionError):
module, _as, alias = line.split()
assert _as == "as"
return cls(module, alias)
with suppress(ValueError, AssertionError):
(module,) = line.split()
return cls(module)
return None
@dataclass(order=True, slots=True)
class Import:
    """A plain `import MODULE [as ALIAS]` statement naming a single module.

    `import a, b` is rejected by `ImportInner.try_parse`, so such lines are not
    turned into `Import` objects.
    """

    inner: ImportInner

    def __str__(self):
        return f"import {self.inner}"

    @classmethod
    def try_parse(cls, line: str) -> Self | None:
        """Parse *line* as `import MODULE [as ALIAS]`, or return None.

        The checks are explicit rather than `assert`-based: `python -O` strips
        asserts, which would make any two-word line such as `return x` parse
        as `import x`.
        """
        parts = line.split(maxsplit=1)
        if len(parts) != 2 or parts[0] != "import":
            return None
        inner = ImportInner.try_parse(parts[1])
        return None if inner is None else cls(inner)
@dataclass(order=True, slots=True)
class From:
    """A `from MODULE import NAME[, NAME ...]` statement.

    `imports` is kept sorted, so the generated ordering (module first, then
    the list of names) sorts stanza entries by module and then by names.
    """

    module: str
    imports: list[ImportInner]

    def __str__(self):
        # Join outside the f-string: reusing `"` inside a replacement field is
        # only valid from Python 3.12 (PEP 701), whereas the rest of the file
        # needs just 3.11 (`typing.Self`).
        names = ", ".join(map(str, self.imports))
        return f"from {self.module} import {names}"

    @staticmethod
    def _parse_names(text: str) -> list[ImportInner] | None:
        """Parse comma-separated `NAME [as ALIAS]` items; None if any item is malformed or empty."""
        names = []
        for part in text.split(","):
            name = ImportInner.try_parse(part.strip())
            if name is None:
                return None
            names.append(name)
        return names

    @classmethod
    def _parse_parenthesized(cls, module: str, lines: Iterator[str]) -> Self | None:
        """Read the body of `from MODULE import (` from *lines* through the closing `)`.

        Each line may hold several comma-separated names plus one trailing
        comma; blank lines are skipped and the `)` may follow the last name on
        the same line. Reaching the end of *lines* without a `)` keeps the
        names read so far. Returns None on the first line that does not parse.
        """
        names: list[ImportInner] = []
        for raw in lines:
            text = raw.strip()
            closed = text.endswith(")")
            if closed:
                text = text[:-1].rstrip()
            if text.endswith(","):
                text = text[:-1]
            if text:
                parsed = cls._parse_names(text)
                if parsed is None:
                    return None
                names.extend(parsed)
            if closed:
                break
        names.sort()
        return cls(module, names)

    @classmethod
    def try_parse(cls, line: str, lines: Iterator[str]) -> Self | None:
        """Parse a `from` import that starts on *line*, or return None.

        Recognised forms:
          * `from MODULE import NAME [as ALIAS], ...` on one line;
          * `from MODULE import (` on *line*, with the names on following
            lines taken from *lines* (see `_parse_parenthesized`).

        Only the parenthesized form reads from *lines*. If it fails partway
        (e.g. on a comment inside the parentheses), the lines already taken
        are not given back; a caller that must preserve them has to record
        them itself.

        Validation uses explicit checks, not `assert`: under `python -O` the
        asserts vanish and any four-word line (e.g. `if a == b:`) would be
        taken for a parenthesized import.
        """
        words = line.split()
        if len(words) == 4 and words[0] == "from" and words[2] == "import" and words[3] == "(":
            # Never fall back to the single-line form here: it would accept
            # "(" as an imported name.
            return cls._parse_parenthesized(words[1], lines)

        parts = line.split(maxsplit=3)
        if len(parts) != 4 or parts[0] != "from" or parts[2] != "import":
            return None
        names = cls._parse_names(parts[3])
        if names is None:
            return None
        names.sort()
        return cls(parts[1], names)
def _recording(source: Iterator[str], sink: list[str]) -> Iterator[str]:
    """Yield items from *source*, appending each one to *sink* as it is taken."""
    for item in source:
        sink.append(item)
        yield item


def sort_imports(filename: str) -> None:
    """Sort the import stanzas of *filename* in place.

    A stanza is a run of consecutive import lines (as recognised by
    `From.try_parse` / `Import.try_parse`) sharing the same indentation.
    Within a stanza the `from` imports are written first, then the `import`
    imports, each group sorted. Every other line is copied through unchanged.
    The file is read as UTF-8 and only rewritten if the result differs from
    its current contents.
    """
    with open(filename, "r+", encoding="utf-8") as infile:
        original = infile.read()
        out = StringIO()
        indent = ""
        from_imports = []
        imports = []

        def flush_imports():
            """Write the pending stanza (from-imports, then imports) and reset it."""
            from_imports.sort()
            imports.sort()
            for from_imp in from_imports:
                print(indent, from_imp, sep="", file=out)
            for imp in imports:
                print(indent, imp, sep="", file=out)
            from_imports.clear()
            imports.clear()

        def update_indent(line: str):
            """Flush the current stanza if *line* is indented differently from it."""
            nonlocal indent
            new_indent = line[: len(line) - len(line.lstrip())]
            if new_indent != indent:
                flush_imports()
                indent = new_indent

        source = StringIO(original)
        for line in source:
            # Lines a multi-line `from ... import (` parse pulls from `source`;
            # if that parse fails they must be written back verbatim.
            consumed: list[str] = []
            if from_imp := From.try_parse(line, _recording(source, consumed)):
                update_indent(line)
                from_imports.append(from_imp)
            elif consumed:
                flush_imports()
                out.write(line)
                out.writelines(consumed)
            elif imp := Import.try_parse(line):
                update_indent(line)
                imports.append(imp)
            else:
                flush_imports()
                out.write(line)
        # Emit the last stanza when the file ends on import lines; otherwise
        # those imports would be dropped.
        flush_imports()

        result = out.getvalue()
        if result != original:
            infile.seek(0)
            infile.write(result)
            infile.truncate()
def main():
    """Command-line entry point: sort imports in every file named in sys.argv[1:].

    With no file arguments, prints usage to stderr and exits with status 1.
    A file that cannot be read, decoded as UTF-8, or written is reported on
    stderr and skipped; after all other files are processed the exit status
    is then 1.
    """
    if len(sys.argv) < 2:
        print("Usage: ./sort-imports.py file1.py file2.py ...", file=sys.stderr)
        sys.exit(1)
    failed = False
    for filename in sys.argv[1:]:
        try:
            sort_imports(filename)
        except (OSError, UnicodeDecodeError) as err:
            # Keep going so one bad path does not leave the other files unsorted.
            print(f"{filename}: {err}", file=sys.stderr)
            failed = True
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()
@milkey-mouse
Copy link
Author

License: CC0-1.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment