Created March 24, 2025 15:10
-
-
Save vergenzt/adeda6f8bff11715acb98e13a477d058 to your computer and use it in GitHub Desktop.
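Fast Python script that writes the Alembic head revision to a tracked file, so that branches which each add a migration (producing multiple heads) surface as a git merge conflict on that file.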
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
import ast
import os
import sys
from argparse import ArgumentParser, FileType
from collections import defaultdict
from configparser import ConfigParser
from graphlib import TopologicalSorter
from subprocess import CalledProcessError, check_call, check_output
from typing import Any, Iterable, Iterator, Tuple
def _get_lit_assignments(migration_path: str) -> Iterator[Tuple[str, Any]]: | |
with open(migration_path, "r") as migration_file: | |
mod: ast.Module = ast.parse(migration_file.read(), migration_path) | |
for stmt in mod.body: | |
match stmt: | |
case ast.Assign([ast.Name(var)], expr) | ast.AnnAssign(ast.Name(var), _, expr) if expr: | |
try: | |
yield var, ast.literal_eval(expr) | |
except ValueError: | |
pass | |
case _: | |
pass | |
def _get_alembic_dependency(migration_path: str) -> tuple[str, tuple[str, ...]]: | |
""" | |
Get revision ID & down_revision ID(s) for given migration file. | |
""" | |
lits = dict(_get_lit_assignments(migration_path)) | |
try: | |
revision: str = lits["revision"] | |
down_revision: str | Iterable[str] | None = lits["down_revision"] | |
except KeyError as err: | |
raise ValueError( | |
f"Migration {migration_path} did not have const {err.args[0]!r} declaration" | |
) | |
down_revision_tup: tuple[str, ...] | |
match down_revision: | |
case _ if isinstance(down_revision, str): | |
down_revision_tup = (down_revision,) | |
case _ if isinstance(down_revision, Iterable): | |
down_revision_tup = tuple(down_revision) | |
case None: | |
down_revision_tup = () | |
return revision, down_revision_tup | |
def get_alembic_heads(*migration_paths: str) -> str:
    """
    Get head revision ID(s) for given migration files.

    Uses `ast` module & plaintext reads for speed. (`alembic heads` takes ~2s vs this which takes ~200ms.)

    A head is a revision that no given migration names as its down_revision.

    Returns:
        The single head revision ID followed by a newline.

    Raises:
        ValueError: if no head is found (no migrations given) or more than one
            head is found.
        graphlib.CycleError: if the down_revision links form a cycle.
    """
    # prev_revision -> next_revision(s). TopologicalSorter interprets each value
    # set as the key's *predecessors*, so in its view every edge is reversed and
    # the nodes ready first are exactly the revisions nothing builds on: the heads.
    graph = defaultdict[str, set[str]](set)
    for revision, prev_revisions in map(_get_alembic_dependency, migration_paths):
        # Register every revision as a node, even one with no down_revision and
        # no successors (e.g. a lone initial migration); otherwise it would never
        # enter the graph and could not be reported as a head.
        graph.setdefault(revision, set())
        for prev_revision in prev_revisions:
            graph[prev_revision].add(revision)

    sorter = TopologicalSorter(graph)
    sorter.prepare()  # raises graphlib.CycleError on cyclic history
    heads = sorter.get_ready()
    if not heads:
        raise ValueError("No migration heads detected")
    if len(heads) > 1:
        raise ValueError(f"Multiple migration heads detected: {heads}")
    (head,) = heads
    return head + "\n"
def main():
    """
    Command-line entry point.

    Steps:

    1. Read the alembic.ini given by ``-c`` and find the git-tracked migration
       files under ``<script_location>/versions/``.
    2. Compute their single head revision, print it, and write it to the file
       named by the ``[alembic_head_file] path`` option.
    3. With ``--check`` (the default when ``$CI`` is set), run
       ``git diff --exit-code`` on that file and exit with git's status if it
       changed. Otherwise, stage the file with ``git add``.

    Exits with status 1 and a one-line message if head detection fails.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--alembic-config",
        help="Path to alembic.ini config file",
        type=FileType("r"),
        required=True,
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check if head revision file is up-to-date and exit 1 if not (the default if $CI is set)",
        default=bool(os.getenv("CI")),
    )
    args = parser.parse_args()

    # Predefine `here` (the ini file's directory) as alembic's own Config does,
    # so values like `script_location = %(here)s/alembic` interpolate instead of
    # raising InterpolationMissingOptionError. The `with` closes the file.
    with args.alembic_config as config_file:
        here = os.path.abspath(os.path.dirname(config_file.name))
        config = ConfigParser({"here": here})
        config.read_file(config_file)

    script_location = config["alembic"]["script_location"]
    # statically assume `version_locations` is `versions` for now, for simplicity
    version_paths_glob = f"{script_location}/versions/*.py"

    # only check migrations which are in git index
    # (so that unstaged new migrations don't get included).
    # `-z` yields NUL-separated, unquoted paths, so filenames with non-ASCII or
    # special characters are returned verbatim rather than C-quoted.
    ls_files_output = check_output(
        ["git", "ls-files", "-z", "--", version_paths_glob], text=True
    )
    version_paths = [path for path in ls_files_output.split("\0") if path]

    try:
        heads_text = get_alembic_heads(*version_paths)
    except ValueError as err:
        # No/multiple heads, a cycle (graphlib.CycleError subclasses ValueError),
        # or a migration missing its revision declarations: report and exit 1.
        sys.exit(f"error: {err}")

    heads_path = config["alembic_head_file"]["path"]
    print(f"Writing head revision(s) to {heads_path}:", file=sys.stderr)
    # heads_text already ends with a newline; don't let print add a blank line
    print(heads_text, end="")
    with open(heads_path, "w") as heads_file:
        heads_file.write(heads_text)

    if args.check:
        # `git diff --exit-code` exits 1 when the file differs from the index;
        # propagate that status instead of dying with a CalledProcessError traceback.
        try:
            check_call(["git", "diff", "--exit-code", "--", heads_path])
        except CalledProcessError as err:
            sys.exit(err.returncode)
    else:
        # assume we're in pre-commit hook; add any changes to the index for commit
        check_call(["git", "add", "--", heads_path])


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.