Skip to content

Instantly share code, notes, and snippets.

@vergenzt
Created March 24, 2025 15:10
Show Gist options
  • Save vergenzt/adeda6f8bff11715acb98e13a477d058 to your computer and use it in GitHub Desktop.
Save vergenzt/adeda6f8bff11715acb98e13a477d058 to your computer and use it in GitHub Desktop.
Fast Python script that writes the Alembic head revision to a file, so that branches adding competing migrations (multiple heads) show up as git conflicts
#!/usr/bin/env python3
import ast
import os
import sys
from argparse import ArgumentParser, FileType
from collections import defaultdict
from configparser import ConfigParser
from graphlib import TopologicalSorter
from subprocess import check_call, check_output
from typing import Any, Iterable, Iterator, Tuple
def _get_lit_assignments(migration_path: str) -> Iterator[Tuple[str, Any]]:
with open(migration_path, "r") as migration_file:
mod: ast.Module = ast.parse(migration_file.read(), migration_path)
for stmt in mod.body:
match stmt:
case ast.Assign([ast.Name(var)], expr) | ast.AnnAssign(ast.Name(var), _, expr) if expr:
try:
yield var, ast.literal_eval(expr)
except ValueError:
pass
case _:
pass
def _get_alembic_dependency(migration_path: str) -> tuple[str, tuple[str, ...]]:
"""
Get revision ID & down_revision ID(s) for given migration file.
"""
lits = dict(_get_lit_assignments(migration_path))
try:
revision: str = lits["revision"]
down_revision: str | Iterable[str] | None = lits["down_revision"]
except KeyError as err:
raise ValueError(
f"Migration {migration_path} did not have const {err.args[0]!r} declaration"
)
down_revision_tup: tuple[str, ...]
match down_revision:
case _ if isinstance(down_revision, str):
down_revision_tup = (down_revision,)
case _ if isinstance(down_revision, Iterable):
down_revision_tup = tuple(down_revision)
case None:
down_revision_tup = ()
return revision, down_revision_tup
def get_alembic_heads(*migration_paths: str) -> str:
    """
    Get head revision ID(s) for given migration files.

    Uses `ast` module & plaintext reads for speed. (`alembic heads` takes ~2s vs this which takes ~200ms.)

    A head is a revision that no given migration names in its ``down_revision``.

    Returns:
        The single head revision ID followed by a newline.

    Raises:
        ValueError: if no head or more than one head is found, if a migration
            lacks literal ``revision``/``down_revision`` assignments, or (as
            `graphlib.CycleError`, a `ValueError` subclass) if the revisions
            form a cycle.
    """
    # Edges point from a revision to the revisions built on top of it. Since
    # TopologicalSorter reads each set as the node's *predecessors*, the nodes
    # that are "ready" first are those with nothing built on them: the heads.
    graph = defaultdict[str, set[str]](set)
    for revision, prev_revisions in map(_get_alembic_dependency, migration_paths):
        # Register every revision as a node, even one with no down_revision and
        # nothing built on it (e.g. a lone initial migration); otherwise it never
        # enters the graph and cannot be reported as a head.
        graph.setdefault(revision, set())
        for prev_revision in prev_revisions:
            # prev_revision -> next_revision(s)
            graph[prev_revision].add(revision)

    sorter = TopologicalSorter(graph)
    sorter.prepare()
    heads = sorter.get_ready()
    if not heads:
        raise ValueError("No migration heads detected")
    if len(heads) > 1:
        raise ValueError(f"Multiple migration heads detected: {heads}")
    return "".join(head + "\n" for head in heads)
def main():
    """
    Command-line entry point: write the Alembic head revision to the head file.

    Reads ``script_location`` from the ``[alembic]`` section and ``path`` from the
    ``[alembic_head_file]`` section of the given alembic.ini. It finds the head of
    the git-tracked migrations under ``<script_location>/versions/``, prints it and
    writes it to the head file. In ``--check`` mode it then fails if ``git diff``
    shows the head file changed; otherwise it stages the head file with ``git add``.

    Exits with status 1 and a message on stderr when the head cannot be determined
    (no head, multiple heads, a cycle, or a migration without literal revision
    IDs). When the check fails, it exits with ``git diff``'s exit status.
    """
    # Local import: only needed to turn a failed `git diff` into a clean exit.
    from subprocess import CalledProcessError

    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--alembic-config",
        help="Path to alembic.ini config file",
        type=FileType("r"),
        required=True,
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check if head revision file is up-to-date and exit 1 if not (the default if $CI is set)",
        default=bool(os.getenv("CI")),
    )
    args = parser.parse_args()

    # argparse.FileType opens the file but never closes it; close it once parsed.
    with args.alembic_config as config_file:
        # Alembic predefines `here` (the ini file's directory) for `%(here)s`
        # substitutions; provide the same so such configs don't fail interpolation.
        # `%` is doubled because ConfigParser interpolates default values too.
        here = os.path.dirname(os.path.abspath(config_file.name))
        config = ConfigParser(defaults={"here": here.replace("%", "%%")})
        config.read_file(config_file)
    script_location = config["alembic"]["script_location"]

    # statically assume `version_locations` is `versions` for now, for simplicity
    version_paths_glob = f"{script_location}/versions/*.py"
    # only check migrations which are in git index
    # (so that unstaged new migrations don't get included)
    version_paths = check_output(["git", "ls-files", version_paths_glob], text=True).splitlines()

    try:
        heads_text = get_alembic_heads(*version_paths)
    except ValueError as err:
        # Expected failure modes (e.g. multiple heads after a merge) are reported
        # without a traceback; sys.exit(str) prints to stderr and exits 1.
        sys.exit(f"error: {err}")

    heads_path = config["alembic_head_file"]["path"]
    print(f"Writing head revision(s) to {heads_path}:", file=sys.stderr)
    # heads_text already ends with a newline
    print(heads_text, end="")
    with open(heads_path, "w") as heads_file:
        heads_file.write(heads_text)

    if args.check:
        try:
            check_call(["git", "diff", "--exit-code", heads_path])
        except CalledProcessError as err:
            # Propagate git's status (1 when the head file differs) without a traceback.
            sys.exit(err.returncode)
    else:
        # assume we're in pre-commit hook; add any changes to the index for commit
        check_call(["git", "add", heads_path])
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment