Last active
June 10, 2024 20:38
-
-
Save lagru/b085c9c6c23a952dd2a1022cdf1a9398 to your computer and use it in GitHub Desktop.
A small convenience script to checkout pull requests as a local branch and clean up once done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Checkout GitHub pull requests locally. | |
A small convenience script to checkout pull requests as a local branch | |
and clean up once done. | |
Inspired by git-pr [1]_ from Stéfan van der Walt. | |
.. [1] https://github.com/stefanv/git-tools/blob/10cd994c8737e5192ada06d850ecbdbe3f223e34/scripts/git-pr | |
""" | |
import sys | |
import re | |
import argparse | |
import subprocess | |
import urllib.request | |
import json | |
import shlex | |
import traceback | |
from contextlib import contextmanager | |
# SSH prefix used to build the clone URL of a pull request author's fork,
# i.e. "git@github.com:<login>/<repo>".
BASE_URL = "git@github.com:"
# GitHub REST API prefix; main() appends "<owner>/<repo>/pulls/<number>".
GITHUB_REST_URL = "https://api.github.com/repos/"
# Local branches that remove_pr_branch() refuses to delete.
BLACKLIST_BRANCH_NAMES = ("main", "master")
# Remotes that remove_pr_branch() refuses to remove.
BLACKLIST_REMOTE_NAMES = ("origin", "upstream", "upstream-writeable")
def red(text: object) -> str:
    """Wrap `text` with bold red ANSI escape codes.

    `text` may be any object (callers pass exceptions, for example); it is
    rendered with normal f-string formatting.
    """
    return f"\033[31;1m{text}\033[0m"
def blue(text: object) -> str:
    """Wrap `text` with (non-bold) blue ANSI escape codes."""
    return f"\033[34m{text}\033[0m"
def bold(text) -> str:
    """Return `text` framed by the ANSI "bold" and "reset" escape codes."""
    start, reset = "\033[1m", "\033[0m"
    return f"{start}{text}{reset}"
def run(cmd: str, *args: str, check: bool = True, show: bool = False) -> str:
    """Run a command while handling printing and return its standard output.

    `cmd` should only contain trusted input, while `args` might contain input
    from untrusted sources, e.g. fetched data from the internet. `cmd` is
    split with `shlex.split`; every entry of `args` is passed through as one
    argument, and no shell is ever involved.

    Parameters
    ----------
    cmd : str
        Trusted command prefix, e.g. ``"git branch -D"``.
    *args : str
        Additional, possibly untrusted, arguments appended verbatim.
    check : bool, optional
        If True, raise `subprocess.CalledProcessError` on a non-zero exit
        status.
    show : bool, optional
        If True, echo the (shell-quoted) command line and its stripped output.

    Returns
    -------
    str
        The command's standard output (stderr is not captured).

    Raises
    ------
    ValueError
        If `cmd` is empty.
    subprocess.CalledProcessError
        If `check` is True and the command exits with a non-zero status.
    """
    cmd_parts = shlex.split(cmd)
    # Validate before echoing anything, so an empty command is never printed.
    if not cmd_parts or not cmd_parts[0]:
        raise ValueError(f"cmd appears to be empty: {cmd_parts!r}")
    if show:
        # Echo a copy-pasteable command line: only the untrusted `args` need
        # quoting, and no trailing space is printed when there are none.
        print(bold(" ".join(["$", cmd, *map(shlex.quote, args)])))
    result = subprocess.run(
        cmd_parts + list(args),
        check=check,
        text=True,
        stdout=subprocess.PIPE,
        # Never use shell=True here to prevent shell injections,
        # as args might contain input fetched from the web
        shell=False,
    )
    output = result.stdout
    if show and output:
        print(output.strip())
    return output
def parse_command_line():
    """Parse the command line into keyword arguments for `main`.

    Returns
    -------
    dict
        Mapping with the keys ``pr_number`` (str: a positive decimal number
        without leading zeros), ``done`` (bool), ``ref_remote`` (str) and
        ``fallback_branch`` (str).
    """

    def pr_number_type(value):
        # Reject anything but a positive integer up front: the value ends up
        # in the API URL and in local branch names, and a typo should give a
        # usage error rather than an HTTP traceback. Normalize it (e.g.
        # "007" -> "7") so checkout and `--done` derive the same branch name,
        # and keep it a string because `main` only interpolates it.
        if not re.fullmatch(r"[0-9]+", value) or int(value) == 0:
            raise argparse.ArgumentTypeError(
                f"invalid pull request number: {value!r}"
            )
        return str(int(value))

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "pr_number",
        metavar="NUMBER",
        type=pr_number_type,
        help="Number of the pull request to check out",
    )
    parser.add_argument(
        "-d",
        "--done",
        action="store_true",
        help="Delete pull request branch and remote if they exist",
    )
    parser.add_argument(
        "--remote",
        metavar="REMOTE",
        dest="ref_remote",
        default="upstream",
        help="Remote from which pull requests are taken (default: 'upstream')",
    )
    parser.add_argument(
        "--fallback",
        dest="fallback_branch",
        metavar="BRANCH",
        default="main",
        help="When deleting, switch to this branch before doing so (default: 'main')",
    )
    return vars(parser.parse_args())
def remove_pr_branch(local_branch_name, *, fallback_branch, local_remote_name):
    """Remove the branch of a previously checked out pull request.

    Given the `local_branch_name` of a previously checked out pull request,
    delete it, and remove `local_remote_name` as well if no remaining local
    branch tracks that remote. Switch to `fallback_branch` first if the
    current branch is the one to delete. Deleting the branch and removing the
    remote are best effort: failures of those git commands are ignored.

    Raises
    ------
    RuntimeError
        If `local_branch_name` is in `BLACKLIST_BRANCH_NAMES` or
        `local_remote_name` is in `BLACKLIST_REMOTE_NAMES`.
    subprocess.CalledProcessError
        If querying the current branch or switching to `fallback_branch`
        fails.
    """
    if local_branch_name in BLACKLIST_BRANCH_NAMES:
        raise RuntimeError(
            f"requested to delete blacklisted branch `{local_branch_name}`"
        )
    if local_remote_name in BLACKLIST_REMOTE_NAMES:
        raise RuntimeError(
            f"requested to remove blacklisted remote `{local_remote_name}`"
        )
    # Only leave the current branch if it is the one about to be deleted
    # (git refuses to delete the checked-out branch).
    current_branch = run("git branch --show-current").strip()
    if current_branch == local_branch_name:
        run("git switch", fallback_branch, show=True)
    run("git branch -D", local_branch_name, check=False, show=True)
    # Compare exact upstream remote names of the remaining local branches.
    # A substring search in `git branch -vv` output would also match e.g.
    # `_bobby/...` or a branch named `pr/1_bob-fix` for the remote `_bob`
    # and keep a stale remote around.
    tracked_remotes = set(
        run("git for-each-ref --format=%(upstream:remotename) refs/heads").split()
    )
    if local_remote_name not in tracked_remotes:
        run("git remote remove", local_remote_name, check=False, show=True)
@contextmanager
def handle_exceptions():
    """Handle (un)expected exceptions in `main()`.

    A failing subprocess (`subprocess.CalledProcessError`) is reported with
    its one-line summary; any other `Exception` is reported with its full
    traceback. Both are printed in red to stderr, followed by exit status 1.
    `SystemExit` and `KeyboardInterrupt` derive from `BaseException`, not
    `Exception`, so they propagate unchanged.
    """
    try:
        yield
    except subprocess.CalledProcessError as error:
        print(red(error), file=sys.stderr)
        sys.exit(1)
    except Exception:
        print(red(traceback.format_exc()), file=sys.stderr)
        sys.exit(1)
def main(*, pr_number: str, done: bool, ref_remote: str, fallback_branch: str):
    """Check out a GitHub pull request locally, or clean it up with `done`.

    Check `parse_command_line` for the meaning of the parameters.

    The owner and repository are derived from the URL of the `ref_remote`
    remote, and the pull request metadata is read from the GitHub REST API.
    The PR's head branch is fetched from a remote named ``_<author login>``
    and checked out as ``pr/<pr_number>_<head branch>``. If that local branch
    already exists, it is switched to and the fetched head branch is merged.

    Raises
    ------
    RuntimeError
        If the URL of `ref_remote` does not end in ``<owner>/<repo>``.
    subprocess.CalledProcessError
        If a required git command fails.
    urllib.error.URLError
        If the GitHub API request fails (e.g. unknown pull request number).
    """
    repo_url = run("git config --get", f"remote.{ref_remote}.url").strip()
    # Accept SSH ("git@github.com:owner/repo.git") and HTTPS
    # ("https://github.com/owner/repo") URLs, with or without ".git" or a
    # trailing slash; repository names may contain dots.
    match = re.search(
        r"[:/](?P<owner>[\w.-]+)/(?P<repo>[\w.-]+?)(?:\.git)?/?$", repo_url
    )
    if match is None:
        raise RuntimeError(
            f"cannot determine GitHub owner/repo of remote `{ref_remote}` "
            f"from its URL {repo_url!r}"
        )
    ref_owner = match["owner"]
    ref_repo = match["repo"]

    pr_url = f"{GITHUB_REST_URL}{ref_owner}/{ref_repo}/pulls/{pr_number}"
    request = urllib.request.Request(
        pr_url, headers={"Accept": "application/vnd.github+json"}
    )
    with urllib.request.urlopen(request) as response:
        pr_data = json.loads(response.read())

    pr_title = pr_data["title"]
    pr_html_url = pr_data["html_url"]
    remote_name = pr_data["head"]["user"]["login"]
    branch_name = pr_data["head"]["ref"]
    # The author's fork may be named differently from the upstream repo.
    # `head.repo` can be null (e.g. deleted fork); fall back to the upstream
    # name then, as before.
    head_repo = pr_data["head"].get("repo")
    fork_repo = head_repo["name"] if head_repo else ref_repo
    print(blue(f"{pr_title}\n{pr_html_url}\n{remote_name}:{branch_name}"))

    local_remote_name = f"_{remote_name}"
    local_branch_name = f"pr/{pr_number}_{branch_name}"

    if done:
        remove_pr_branch(
            local_remote_name=local_remote_name,
            fallback_branch=fallback_branch,
            local_branch_name=local_branch_name,
        )
        return

    # Compare exact names: a substring test on the raw `git remote` output
    # would treat `_bob` as present when only `_bobby` exists.
    if local_remote_name not in run("git remote").split():
        run(
            "git remote add",
            local_remote_name,
            f"{BASE_URL}{remote_name}/{fork_repo}",
            show=True,
        )
    run("git fetch", local_remote_name, branch_name, show=True)

    # `rev-parse --verify --quiet` prints the commit hash if the exact local
    # branch exists and nothing otherwise (exiting non-zero, hence
    # check=False); this avoids prefix matches like `pr/1_foo` vs
    # `pr/1_foobar`.
    branch_exists = run(
        "git rev-parse --verify --quiet",
        f"refs/heads/{local_branch_name}",
        check=False,
    ).strip()
    if not branch_exists:
        run(
            "git checkout -b",
            local_branch_name,
            f"{local_remote_name}/{branch_name}",
            show=True,
        )
        # Remember which pull request this local branch belongs to.
        run("git config", f"branch.{local_branch_name}.description", pr_html_url)
    else:
        run("git switch", local_branch_name, show=True)
        run("git merge", f"{local_remote_name}/{branch_name}", show=True)
if __name__ == "__main__":
    # Parse the arguments and run the script inside handle_exceptions(), so
    # that ordinary exceptions from either step become a short red report.
    with handle_exceptions():
        main(**parse_command_line())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment