Skip to content

Instantly share code, notes, and snippets.

@code-yeongyu
Created November 3, 2025 08:26
Show Gist options
  • Select an option

  • Save code-yeongyu/b5d5b14c89149fe2396a8cef5d92e505 to your computer and use it in GitHub Desktop.

Select an option

Save code-yeongyu/b5d5b14c89149fe2396a8cef5d92e505 to your computer and use it in GitHub Desktop.
Makes sure the commits of the PRs you opened are rebased on top of their base branch. Git worktrees, nested (stacked) branches, and default branches not named 'main' are all supported, because the default branch name is fetched directly from the gh CLI. It is a PEP 723 script that declares its own dependencies (including 'rich'), so it is convenient to run and looks great, …
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "typer>=0.12.0",
# "rich>=13.0.0",
# ]
# ///
# pyright: reportMissingImports=false
"""
Git branch rebase automation tool.
Rebases multiple branches on origin/main and force pushes with worktree support.
"""
import asyncio
import json
import os
import re
import subprocess
from dataclasses import dataclass
from graphlib import TopologicalSorter
from pathlib import Path
import typer
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TaskID, TextColumn
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text
# Shared Rich console used for every piece of user-facing output.
console = Console()
# Typer application; the `main` command defined below is registered on it.
app = typer.Typer(help="Rebase multiple branches on origin/main and force push")

# Canonical GitHub pull-request URL; capture group 1 is the PR number.
_PR_URL_REGEX = r"^https://github\.com/[^/]+/[^/]+/pull/(\d+)$"
PR_URL_PATTERN = re.compile(_PR_URL_REGEX)
@dataclass()
class PRInfo:
    """Pull request metadata, populated from the JSON that ``gh pr list`` returns."""

    # PR number (the ``#123`` part of the URL).
    number: int
    # PR title as shown on GitHub.
    title: str
    # Web URL of the PR.
    url: str
    # Branch the PR targets (gh's ``baseRefName``).
    base_branch: str
@dataclass()
class BranchInfo:
    """A branch queued for processing, with its PR and rebase status.

    Only ``name`` is required. The other fields default to: no PR known,
    based on ``main``, needs a rebase, zero commits behind.
    """

    # Branch name (also looked up remotely as ``origin/<name>``).
    name: str
    # Open PR whose head is this branch, if one was found.
    pr: PRInfo | None = None
    # Branch this one is rebased onto.
    base_branch: str = "main"
    # False when origin/<name> already contains every commit of origin/<base_branch>.
    needs_rebase: bool = True
    # Number of commits on origin/<base_branch> missing from origin/<name>.
    commits_behind: int = 0
class GitError(Exception):
    """Raised when an external git/gh command fails or yields unusable output."""
async def run_command(*args: str, cwd: Path | None = None, check: bool = True) -> tuple[int, str, str]:
"""Run shell command asynchronously."""
proc = await asyncio.create_subprocess_exec(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
)
stdout_bytes, stderr_bytes = await proc.communicate()
stdout = stdout_bytes.decode().strip()
stderr = stderr_bytes.decode().strip()
if check and proc.returncode != 0:
raise GitError(f"Command failed: {' '.join(args)}\n{stderr}")
return proc.returncode or 0, stdout, stderr
def run_command_sync(*args: str, cwd: Path | None = None, check: bool = True) -> tuple[int, str, str]:
"""Run shell command synchronously."""
result = subprocess.run(
args,
check=False,
cwd=cwd,
capture_output=True,
text=True,
)
if check and result.returncode != 0:
raise GitError(f"Command failed: {' '.join(args)}\n{result.stderr}")
return result.returncode, result.stdout.strip(), result.stderr.strip()
def is_git_repo() -> bool:
    """Return True when ``git rev-parse --is-inside-work-tree`` succeeds in the cwd."""
    try:
        run_command_sync("git", "rev-parse", "--is-inside-work-tree")
    except GitError:
        return False
    return True
def get_current_branch() -> str:
    """Return the checked-out branch name, or the full HEAD commit hash when detached.

    A warning with the abbreviated hash is printed in the detached case.
    """
    _, branch_name, _ = run_command_sync("git", "branch", "--show-current")
    if not branch_name:
        # `git branch --show-current` prints nothing on a detached HEAD.
        _, head_sha, _ = run_command_sync("git", "rev-parse", "HEAD")
        console.print(f"[yellow]⚠ Warning: In detached HEAD state ({head_sha[:7]})[/yellow]")
        return head_sha
    return branch_name
def get_worktree_path(branch: str) -> Path | None:
    """Return the path of the worktree that has ``branch`` checked out, if any.

    Walks ``git worktree list --porcelain`` output: a ``worktree <path>`` line
    opens a record, and a following ``branch ...`` line ending in
    ``refs/heads/<branch>`` identifies the match. Returns None when no record
    matches or the git command fails.
    """
    ref_suffix = f"refs/heads/{branch}"
    try:
        _, porcelain, _ = run_command_sync("git", "worktree", "list", "--porcelain")
    except GitError:
        return None
    current_path = None
    for entry in porcelain.split("\n"):
        if entry.startswith("worktree "):
            current_path = Path(entry.split(" ", 1)[1])
        elif entry.startswith("branch ") and entry.endswith(ref_suffix):
            return current_path
    return None
async def get_default_branch() -> str:
    """Ask the gh CLI for the repository's default branch name.

    Raises:
        GitError: If the ``gh`` call fails or prints nothing.
    """
    query = ("gh", "repo", "view", "--json", "defaultBranchRef", "-q", ".defaultBranchRef.name")
    _, name, _ = await run_command(*query)
    if name:
        return name
    raise GitError("Failed to detect default branch")
async def get_my_prs(base_branch: str, limit: int = 1000) -> list[str]:
    """Return the head branch names of my open PRs that target ``base_branch``.

    Args:
        base_branch: Base branch the PRs must target.
        limit: Maximum number of PRs to request. ``gh pr list`` returns only
            30 results unless ``--limit`` is given, which silently dropped PRs
            for authors with many open PRs.

    Returns:
        One head branch name per PR; an empty list when none are found.

    Raises:
        GitError: If the ``gh pr list`` call fails.
    """
    panel = Panel(
        f"[cyan]Default branch: [bold]{base_branch}[/bold][/cyan]",
        title="[blue]πŸ” Collecting PRs[/blue]",
        border_style="blue",
    )
    console.print(panel)
    _, stdout, _ = await run_command(
        "gh",
        "pr",
        "list",
        "--author",
        "@me",
        "--base",
        base_branch,
        "--state",
        "open",
        "--limit",
        str(limit),
        "--json",
        "headRefName",
        "-q",
        ".[].headRefName",
    )
    # Ignore blank lines so an empty result never yields an empty branch name.
    branches = [line.strip() for line in stdout.splitlines() if line.strip()]
    if not branches:
        console.print("[yellow]⚠ No PRs found[/yellow]")
        return []
    console.print(f"[green]βœ“ Found {len(branches)} PR(s)[/green]\n")
    return branches
async def get_pr_info(branch: str) -> PRInfo | None:
    """Look up the open PR whose head branch is ``branch``.

    Returns a PRInfo built from the first match, or None when there is no open
    PR or the gh output is empty or cannot be parsed.
    """
    requested_fields = "number,title,url,baseRefName"
    try:
        _, raw_json, _ = await run_command(
            "gh", "pr", "list",
            "--head", branch,
            "--state", "open",
            "--json", requested_fields,
            "--limit", "1",
            check=False,
        )
        # Empty output raises JSONDecodeError below; "[]" parses to an empty list.
        records = json.loads(raw_json)
        if not records:
            return None
        first = records[0]
        return PRInfo(
            number=first["number"],
            title=first["title"],
            url=first["url"],
            base_branch=first["baseRefName"],
        )
    except (GitError, json.JSONDecodeError, KeyError, IndexError):
        return None
async def check_rebase_needed(branch: str, base_branch: str) -> tuple[bool, int]:
    """Report whether ``origin/<branch>`` is missing commits from ``origin/<base_branch>``.

    Returns ``(needs_rebase, commits_behind)``, where ``commits_behind`` counts
    the commits reachable from the base but not from the branch. If the count
    cannot be obtained (empty or non-numeric ``git rev-list`` output), returns
    ``(True, 0)`` so the branch is processed rather than skipped.
    """
    revision_range = f"origin/{branch}..origin/{base_branch}"
    try:
        _, count_text, _ = await run_command("git", "rev-list", "--count", revision_range, check=False)
        if not count_text:
            return True, 0
        behind = int(count_text)
    except (GitError, ValueError):
        return True, 0
    return behind > 0, behind
async def find_closest_branch(branch: str, default_branch: str, candidate_branches: set[str]) -> str | None:
    """Guess which other branch ``branch`` is stacked on, using commit distances.

    For each candidate ``c`` this counts the commits on ``origin/<branch>`` that
    are not on ``origin/<c>``. The same count is measured against the default
    branch. The candidate with the smallest count is returned only when it is
    strictly closer than the default branch; otherwise None is returned,
    meaning "use the default branch".

    Candidates with a count of 0 are skipped. A count of 0 means every commit of
    ``branch`` is already in the candidate, i.e. the candidate sits on top of
    ``branch`` (or at the same commit) and cannot be its base. Choosing it would
    make two stacked branches point at each other.

    Args:
        branch: Branch whose base is being guessed.
        default_branch: Repository default branch, used as the fallback base.
        candidate_branches: Branches that may serve as the base; ``branch`` and
            ``default_branch`` are ignored if present.

    Returns:
        The closest candidate's name, or None if the default branch is at least
        as close or no candidate could be measured.
    """
    if not candidate_branches:
        return None

    async def _distance(target: str) -> int | None:
        """Count commits on origin/<branch> missing from origin/<target>; None if unknown."""
        try:
            _, stdout, _ = await run_command(
                "git",
                "rev-list",
                "--count",
                f"origin/{branch}",
                f"^origin/{target}",
                check=False,
            )
            return int(stdout) if stdout else None
        except (GitError, ValueError):
            return None

    # The default branch is normally not among the candidates, so its distance
    # must be measured explicitly. Otherwise any candidate would always "win".
    default_distance = await _distance(default_branch)

    distances: dict[str, int] = {}
    # Sorted so that ties between candidates resolve deterministically.
    for candidate in sorted(candidate_branches - {branch, default_branch}):
        distance = await _distance(candidate)
        if distance is None or distance == 0:
            continue
        distances[candidate] = distance
    if not distances:
        return None

    closest = min(distances, key=distances.__getitem__)
    # An unmeasurable default branch behaves like an infinite distance.
    if default_distance is None or distances[closest] < default_distance:
        return closest
    return None
def extract_branch_from_pr_url(url: str) -> str | None:
    """Resolve a GitHub PR URL to the PR's head branch name via ``gh pr view``.

    Args:
        url: Expected to match ``https://github.com/<owner>/<repo>/pull/<number>``.

    Returns:
        The head branch name, or None when the URL does not match the expected
        form or ``gh`` cannot resolve it. An error is printed in every None case,
        so a user-supplied URL is never dropped silently.
    """
    from rich.markup import escape

    match = PR_URL_PATTERN.match(url)
    if not match:
        # Previously silent: e.g. ".../pull/123/files" was just ignored.
        console.print(
            f"[red]Error: Not a pull request URL "
            f"(expected https://github.com/OWNER/REPO/pull/NUMBER): {escape(url)}[/red]"
        )
        return None
    pr_number = match.group(1)
    console.print(f"[blue]πŸ”— Extracting branch from PR #{pr_number}...[/blue]")
    try:
        _, stdout, _ = run_command_sync("gh", "pr", "view", url, "--json", "headRefName", "-q", ".headRefName")
    except GitError:
        console.print("[red]βœ— Error: Failed to extract branch from PR[/red]")
        return None
    if not stdout:
        console.print("[red]βœ— Error: Failed to extract branch from PR[/red]")
        return None
    return stdout
async def has_uncommitted_changes() -> bool:
    """Return True if ``git status --porcelain`` succeeds and lists any change."""
    status_code, porcelain, _ = await run_command("git", "status", "--porcelain", check=False)
    if status_code != 0:
        return False
    return porcelain.strip() != ""
async def stash_changes() -> bool:
    """Stash uncommitted changes (including untracked files) before switching branches.

    Returns True only if this call actually created a new stash entry.
    ``refs/stash`` is compared before and after the push because
    ``git stash push`` can exit 0 while saving nothing ("No local changes to
    save"), e.g. when the only reported change is inside a submodule. Returning
    True then would make the later ``git stash pop`` restore an unrelated,
    pre-existing stash.
    """
    if not await has_uncommitted_changes():
        return False

    async def _stash_head() -> str:
        """Commit id of the newest stash, or "" when no stash exists."""
        _, sha, _ = await run_command("git", "rev-parse", "--quiet", "--verify", "refs/stash", check=False)
        return sha

    console.print("[cyan]πŸ’Ύ Stashing uncommitted changes...[/cyan]")
    stash_before = await _stash_head()
    returncode, _, _ = await run_command(
        "git", "stash", "push", "-u", "-m", "updatebranch auto-stash", check=False
    )
    if returncode == 0 and await _stash_head() != stash_before:
        console.print("[green]βœ“ Changes stashed[/green]")
        return True
    console.print("[yellow]⚠ Failed to stash changes[/yellow]")
    return False
async def pop_stash() -> bool:
    """Pop the most recent stash entry.

    Returns True on success, and also when git reports "No stash entries found".
    Otherwise prints git's error output and returns False.
    """
    from rich.markup import escape

    console.print("[cyan]πŸ“€ Restoring stashed changes...[/cyan]")
    returncode, _, stderr = await run_command("git", "stash", "pop", check=False)
    if returncode == 0:
        console.print("[green]βœ“ Changes restored[/green]")
        return True
    if "No stash entries found" in stderr:
        return True
    console.print("[yellow]⚠ Failed to restore stashed changes[/yellow]")
    # git output often contains square brackets (e.g. paths like app/[id]/page.tsx).
    # Escape it so Rich prints it verbatim instead of parsing it as markup.
    console.print(f"[dim]{escape(stderr)}[/dim]")
    return False
async def checkout_branch(branch: str, original_dir: Path) -> tuple[bool, Path | None]:
    """Make ``branch`` active, moving into its worktree when a plain checkout fails.

    First tries ``git checkout <branch>`` in place. If that fails and the branch
    is checked out in a worktree, the process working directory is changed to
    that worktree instead.

    Args:
        branch: Branch to activate.
        original_dir: Not used by this function; callers restore the cwd themselves.

    Returns:
        ``(success, worktree_path)``, where ``worktree_path`` is the directory
        entered, or None when no directory change happened.
    """
    checkout_rc, _, _ = await run_command("git", "checkout", branch, check=False)
    if checkout_rc == 0:
        console.print("\t[green]βœ“ Checked out in current location[/green]")
        return True, None

    location = get_worktree_path(branch)
    if location is None:
        console.print("\t[red]βœ— Checkout failed[/red]")
        return False, None

    console.print(
        f"\t[cyan]πŸ“‚ Branch in worktree: {location}\n"
        f"\t Moving to worktree...[/cyan]"
    )
    try:
        os.chdir(location)
        console.print("\t[green]βœ“ Now in worktree[/green]")
        return True, location
    except OSError:
        console.print("\t[red]βœ— Cannot access worktree[/red]")
        return False, None
async def rebase_branch(branch: str, base_branch: str) -> tuple[bool, str]:
    """Rebase whatever is checked out in the cwd onto ``origin/<base_branch>``.

    Runs ``git pull --rebase origin <base_branch>``. On failure, any in-progress
    rebase is aborted (``git rebase --abort``, errors ignored).

    Args:
        branch: Name of the branch being rebased; not used by the git commands,
            which act on the current checkout.
        base_branch: Remote branch to rebase onto.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, stderr)``.
    """
    from rich.markup import escape

    returncode, _, stderr = await run_command("git", "pull", "--rebase", "origin", base_branch, check=False)
    if returncode != 0:
        console.print(f"\t[red]βœ— Rebase failed on origin/{base_branch}[/red]")
        if stderr:
            # Conflict messages often include bracketed paths (e.g. app/[id]/page.tsx),
            # which Rich would otherwise parse as markup tags.
            console.print(f"\t[dim red]{escape(stderr)}[/dim red]")
        await run_command("git", "rebase", "--abort", check=False)
        return False, stderr
    console.print(f"\t[green]βœ“ Rebased on origin/{base_branch}[/green]")
    return True, ""
async def force_push() -> bool:
    """Force-push the current branch with ``git push --force-with-lease``.

    Returns True on success. On failure, git's error output is shown (escaped
    for Rich, consistent with ``rebase_branch``). Without it, failures such as a
    rejected lease or a missing upstream could not be diagnosed.
    """
    from rich.markup import escape

    returncode, _, stderr = await run_command("git", "push", "--force-with-lease", check=False)
    if returncode != 0:
        console.print("\t[red]βœ— Force push failed[/red]")
        if stderr:
            console.print(f"\t[dim red]{escape(stderr)}[/dim red]")
        return False
    console.print("\t[green]βœ“ Force pushed with lease[/green]")
    return True
async def process_branch(branch_info: BranchInfo, original_dir: Path) -> tuple[bool, str]:
    """Check out, rebase and force-push a single branch.

    If checkout moved into another worktree, the working directory is set back
    to ``original_dir`` afterwards, whatever the outcome.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)``.
    """
    checked_out, entered_worktree = await checkout_branch(branch_info.name, original_dir)
    if not checked_out:
        return False, "Checkout failed"
    try:
        rebased, reason = await rebase_branch(branch_info.name, branch_info.base_branch)
        if not rebased:
            return False, reason
        pushed = await force_push()
        return (True, "") if pushed else (False, "Force push failed")
    finally:
        if entered_worktree:
            os.chdir(original_dir)
async def collect_branch_info(branches: list[str], default_base_branch: str) -> list[BranchInfo]:
    """Gather PR, base-branch and rebase status for every branch.

    Each branch's base comes from its open PR when there is one. Otherwise it
    is guessed with ``find_closest_branch``, falling back to
    ``default_base_branch``. Bases other than the default are fetched from
    ``origin`` *before* the rebase status is computed. Otherwise
    ``commits_behind`` for stacked branches would be measured against stale
    or missing remote-tracking refs.

    Returns:
        One BranchInfo per input branch, in input order.
    """
    with console.status("[cyan]πŸ”„ Fetching branch information...[/cyan]", spinner="dots"):
        pr_infos = await asyncio.gather(*(get_pr_info(branch) for branch in branches))

        branch_set = set(branches)
        detected_bases: list[str] = []
        for branch, pr_info in zip(branches, pr_infos):
            if pr_info:
                detected_bases.append(pr_info.base_branch)
                continue
            closest = await find_closest_branch(branch, default_base_branch, branch_set)
            detected_bases.append(closest if closest else default_base_branch)
            if closest and closest != default_base_branch:
                console.print(
                    f"[dim] β„Ή {branch}: No PR found, detected base β†’ {closest}[/dim]"
                )

        # Refresh origin/<base> for non-default bases so the counts below are current.
        extra_bases = sorted({base for base in detected_bases if base != default_base_branch})
        if extra_bases:
            await asyncio.gather(
                *(run_command("git", "fetch", "origin", base, check=False) for base in extra_bases)
            )

        rebase_infos = await asyncio.gather(
            *(check_rebase_needed(branch, base) for branch, base in zip(branches, detected_bases))
        )

    return [
        BranchInfo(
            name=branch,
            pr=pr_info,
            base_branch=base,
            needs_rebase=needs_rebase,
            commits_behind=commits_behind,
        )
        for branch, pr_info, base, (needs_rebase, commits_behind) in zip(
            branches, pr_infos, detected_bases, rebase_infos
        )
    ]
def sort_branches_by_dependency(branch_infos: list[BranchInfo]) -> list[BranchInfo]:
    """Order branches so that every base is processed before the branches stacked on it.

    If branch B's base is branch A (and A is in ``branch_infos``), A comes
    before B. Bases outside the batch (e.g. the default branch) impose no
    ordering. If the dependencies form a cycle, a warning is printed and the
    input order is returned unchanged.
    """
    from graphlib import CycleError

    from rich.markup import escape

    by_name = {info.name: info for info in branch_infos}
    graph: dict[str, set[str]] = {
        info.name: ({info.base_branch} if info.base_branch in by_name else set())
        for info in branch_infos
    }
    try:
        ordered_names = list(TopologicalSorter(graph).static_order())
    except CycleError as e:
        # Only a genuine cycle is tolerated; other exceptions indicate bugs and propagate.
        console.print("[yellow]⚠ Warning: Circular dependency detected, using original order[/yellow]")
        console.print(f"[dim]{escape(str(e))}[/dim]")
        return branch_infos
    return [by_name[name] for name in ordered_names]
def display_branches(branch_infos: list[BranchInfo], default_branch: str) -> None:
    """Print a table of branches with their base, rebase status and PR details.

    Bases other than ``default_branch`` are highlighted so stacked branches
    stand out. PR titles and URLs are escaped because Rich parses string cells
    as markup. A title such as "Fix [/tmp] handling" would otherwise raise a
    MarkupError, and "[wip] ..." would be swallowed as a style tag.
    """
    from rich.markup import escape

    table = Table(
        title="[bold cyan]πŸ“‹ Branches to Process[/bold cyan]",
        show_header=True,
        header_style="bold magenta",
        border_style="cyan",
        title_style="bold cyan",
    )
    table.add_column("Branch", style="blue", no_wrap=False)
    table.add_column("Base", style="cyan", no_wrap=False)
    table.add_column("Status", justify="center", style="white")
    table.add_column("PR #", justify="right", style="yellow")
    table.add_column("Title", style="white")
    table.add_column("URL", style="dim cyan", no_wrap=False)
    for branch_info in branch_infos:
        if branch_info.needs_rebase:
            status = f"[yellow]{branch_info.commits_behind} behind[/yellow]"
        else:
            status = "[green]βœ“ Up to date[/green]"
        if branch_info.base_branch == default_branch:
            base_display = f"[dim cyan]{branch_info.base_branch}[/dim cyan]"
        else:
            base_display = f"[bold magenta]{branch_info.base_branch}[/bold magenta]"
        if branch_info.pr:
            table.add_row(
                branch_info.name,
                base_display,
                status,
                f"#{branch_info.pr.number}",
                # Free-form text from GitHub: show verbatim, never as markup.
                escape(branch_info.pr.title),
                escape(branch_info.pr.url),
            )
        else:
            table.add_row(
                branch_info.name,
                base_display,
                status,
                "[dim]N/A[/dim]",
                "[yellow]No PR found[/yellow]",
                "[dim]N/A[/dim]",
            )
    console.print(table)
@app.command()
def main(
    branches: list[str] = typer.Argument(
        None,
        help="Branch names or PR URLs to process",
    ),
    all_prs: bool = typer.Option(
        False,
        "--all",
        help="Process all your PRs based on the default branch",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompt",
    ),
) -> None:
    """
    Rebase multiple branches on origin/main and force push.
    Supports worktree-checked-out branches and automatically handles them.
    """
    # NOTE: the docstring above is also the CLI --help text, so it is kept verbatim.
    if not is_git_repo():
        console.print("[red]βœ— Error: Not in a git repository[/red]")
        raise typer.Exit(1)
    # Captured up front so the cleanup in `finally` can return the user here.
    original_dir = Path.cwd()
    original_branch = get_current_branch()

    async def async_main() -> None:
        """Run the workflow, then try to restore cwd, branch and stash."""
        stash_created = False
        try:
            stash_created = await stash_changes()
            if stash_created:
                console.print()
            branch_list = list(branches) if branches else []

            # Always detect the real default branch. Previously this only happened
            # with --all, so explicitly named branches were compared against and
            # fetched from "main" even in repositories with another default.
            try:
                default_branch = await get_default_branch()
            except GitError:
                if all_prs:
                    # --all selects PRs by their base branch, so it needs the real default.
                    raise
                default_branch = "main"
                console.print(
                    "[yellow]⚠ Warning: Could not detect the default branch, assuming 'main'[/yellow]"
                )
            if all_prs:
                my_prs = await get_my_prs(default_branch)
                branch_list.extend(my_prs)

            resolved_branches = []
            for item in branch_list:
                if item.startswith("https://github.com/"):
                    branch = extract_branch_from_pr_url(item)
                    if branch:
                        resolved_branches.append(branch)
                else:
                    resolved_branches.append(item)
            unique_branches = sorted(set(resolved_branches))
            if not unique_branches:
                panel = Panel(
                    "[bold]Usage:[/bold] updatebranch [OPTIONS] [branch1|PR_URL1] ...\n\n"
                    "[bold cyan]Options:[/bold cyan]\n"
                    "\t[yellow]--all[/yellow] Process all your PRs\n"
                    "\t[yellow]-y, --yes[/yellow] Skip confirmation\n\n"
                    "[bold cyan]Examples:[/bold cyan]\n"
                    "\t[dim]updatebranch --all[/dim]\n"
                    "\t[dim]updatebranch --all -y[/dim]\n"
                    "\t[dim]updatebranch feature/fix-bug[/dim]\n"
                    "\t[dim]updatebranch -y https://github.com/owner/repo/pull/123[/dim]",
                    title="[red]❌ No branches specified[/red]",
                    border_style="red",
                )
                console.print(panel)
                raise typer.Exit(1)

            fetch_panel = Panel(
                f"[cyan]Fetching latest [bold]origin/{default_branch}[/bold]...[/cyan]",
                title="[blue]πŸ”„ Updating[/blue]",
                border_style="blue",
            )
            console.print(fetch_panel)
            await run_command("git", "fetch", "origin", default_branch)
            console.print("[green]βœ“ Fetch completed[/green]\n")

            branch_infos = await collect_branch_info(unique_branches, default_branch)
            unique_bases = {info.base_branch for info in branch_infos}
            if len(unique_bases) > 1 or default_branch not in unique_bases:
                console.print("[cyan]πŸ”„ Fetching additional base branches...[/cyan]")
                fetch_tasks = [
                    run_command("git", "fetch", "origin", base, check=False)
                    for base in unique_bases
                    if base != default_branch
                ]
                if fetch_tasks:
                    await asyncio.gather(*fetch_tasks)
                console.print("[green]βœ“ All base branches fetched[/green]\n")

            display_branches(branch_infos, default_branch)
            branches_to_process = [info for info in branch_infos if info.needs_rebase]
            skipped_count = len(branch_infos) - len(branches_to_process)
            if skipped_count > 0:
                console.print(f"\n[dim]β„Ή Skipping {skipped_count} branch(es) already up to date[/dim]")
            if not branches_to_process:
                console.print("\n[green]βœ“ All branches are up to date, nothing to do[/green]")
                raise typer.Exit(0)

            branches_to_process = sort_branches_by_dependency(branches_to_process)
            console.print(f"[dim]πŸ“Š Processing order: {' β†’ '.join(info.name for info in branches_to_process)}[/dim]\n")
            if not yes:
                console.print()
                warning = Panel(
                    f"[bold yellow]⚠ This will rebase {len(branches_to_process)} branch(es) on their respective base branches and force push.[/bold yellow]",
                    border_style="yellow",
                )
                console.print(warning)
                if not Confirm.ask("[yellow]Continue?[/yellow]", default=True):
                    console.print("[red]❌ Aborted by user[/red]")
                    raise typer.Exit(0)
            else:
                console.print("\n[green]βœ“ --yes flag detected, skipping confirmation[/green]\n")

            # Names of branches whose checkout/rebase/push failed; drives the summary and exit code.
            failed: list[str] = []
            with Progress(console=console) as progress:
                task: TaskID = progress.add_task("[cyan]Processing branches...", total=len(branches_to_process))
                for idx, branch_info in enumerate(branches_to_process, 1):
                    header = Panel(
                        f"[bold]{branch_info.name}[/bold]",
                        title=f"[cyan]πŸ”§ [{idx}/{len(branches_to_process)}][/cyan]",
                        border_style="green",
                    )
                    console.print(header)
                    success, error_msg = await process_branch(branch_info, original_dir)
                    result_text = Text()
                    if success:
                        result_text.append("βœ“ Successfully processed ", style="green bold")
                        result_text.append(branch_info.name, style="green")
                    else:
                        failed.append(branch_info.name)
                        result_text.append("βœ— Failed to process ", style="red bold")
                        result_text.append(branch_info.name, style="red")
                        if error_msg:
                            result_text.append(f"\n\t{error_msg}", style="dim red")
                    console.print(result_text)
                    console.print()
                    progress.update(
                        task,
                        advance=1,
                        description=f"[cyan]Processing branches... ({idx}/{len(branches_to_process)})[/cyan]",
                    )

            skipped_note = f"\n[dim]Skipped {skipped_count} already up to date[/dim]" if skipped_count > 0 else ""
            if failed:
                # Previously the summary claimed success even when branches had failed.
                total = len(branches_to_process)
                summary = Panel(
                    f"[bold yellow]⚠ {total - len(failed)} of {total} branch(es) processed successfully[/bold yellow]\n"
                    f"[red]Failed: {', '.join(failed)}[/red]"
                    + skipped_note,
                    title="[yellow]⚠ Completed with errors[/yellow]",
                    border_style="yellow",
                )
            else:
                summary = Panel(
                    f"[bold green]βœ“ Processed {len(branches_to_process)} branch(es)[/bold green]"
                    + skipped_note,
                    title="[green]βœ… Complete[/green]",
                    border_style="green",
                )
            console.print(summary)
            if failed:
                # Non-zero exit so callers/scripts notice; the cleanup in `finally` still runs.
                raise typer.Exit(1)
        finally:
            cleanup = Panel(
                f"[cyan]Returning to:[/cyan] [bold]{original_branch}[/bold]",
                title="[blue]🧹 Cleanup[/blue]",
                border_style="blue",
            )
            console.print(cleanup)
            os.chdir(original_dir)
            # check=False never raises GitError, so the status must be inspected
            # directly (the old `except GitError` warning was unreachable).
            checkout_rc, _, _ = run_command_sync("git", "checkout", original_branch, check=False)
            if checkout_rc != 0:
                console.print("[yellow]⚠ Warning: Could not return to original branch[/yellow]")
            if stash_created:
                console.print()
                await pop_stash()

    asyncio.run(async_main())


if __name__ == "__main__":
    app()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment