Created
November 3, 2025 08:26
-
-
Save code-yeongyu/b5d5b14c89149fe2396a8cef5d92e505 to your computer and use it in GitHub Desktop.
Makes sure the commits of the PRs you opened are rebased on top of their base branch. Git worktrees, nested (stacked) branches, and default branch names other than 'main' are all supported, because the default branch name is fetched directly from the gh CLI. It's a PEP 723 script that even declares the 'rich' dependency inline, so it's convenient to run and looks great…
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env -S uv run --script | |
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "typer>=0.12.0", | |
| # "rich>=13.0.0", | |
| # ] | |
| # /// | |
| # pyright: reportMissingImports=false | |
| """ | |
| Git branch rebase automation tool. | |
| Rebases multiple branches on origin/main and force pushes with worktree support. | |
| """ | |
import asyncio
import json
import os
import re
import subprocess
from dataclasses import dataclass
from graphlib import CycleError, TopologicalSorter
from pathlib import Path

import typer
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TaskID, TextColumn
from rich.prompt import Confirm
from rich.table import Table
from rich.text import Text
# GitHub pull-request URL; capture group 1 is the PR number.
PR_URL_PATTERN = re.compile(r"^https://github\.com/[^/]+/[^/]+/pull/(\d+)$")
# Single Rich console shared by every function for terminal output.
console = Console()
# Typer application; `main` below is registered on it as the CLI command.
app = typer.Typer(help="Rebase multiple branches on origin/main and force push")
@dataclass()
class PRInfo:
    """An open GitHub pull request as returned by ``gh pr list --json``.

    The fields correspond to the ``number``, ``title``, ``url`` and
    ``baseRefName`` JSON keys.
    """

    number: int  # PR number
    title: str  # PR title
    url: str  # web URL of the PR
    base_branch: str  # branch the PR targets (baseRefName)
@dataclass()
class BranchInfo:
    """A branch to process, with its open PR (if any) and its rebase status."""

    name: str  # branch name
    pr: PRInfo | None = None  # open PR whose head is this branch, if one was found
    base_branch: str = "main"  # branch to rebase onto
    needs_rebase: bool = True  # whether the branch should be rebased
    commits_behind: int = 0  # commits on the base missing from the branch (0 if unknown)
class GitError(Exception):
    """Raised when a git/gh subprocess fails (non-zero exit with checking enabled)."""
| async def run_command(*args: str, cwd: Path | None = None, check: bool = True) -> tuple[int, str, str]: | |
| """Run shell command asynchronously.""" | |
| proc = await asyncio.create_subprocess_exec( | |
| *args, | |
| stdout=asyncio.subprocess.PIPE, | |
| stderr=asyncio.subprocess.PIPE, | |
| cwd=cwd, | |
| ) | |
| stdout_bytes, stderr_bytes = await proc.communicate() | |
| stdout = stdout_bytes.decode().strip() | |
| stderr = stderr_bytes.decode().strip() | |
| if check and proc.returncode != 0: | |
| raise GitError(f"Command failed: {' '.join(args)}\n{stderr}") | |
| return proc.returncode or 0, stdout, stderr | |
| def run_command_sync(*args: str, cwd: Path | None = None, check: bool = True) -> tuple[int, str, str]: | |
| """Run shell command synchronously.""" | |
| result = subprocess.run( | |
| args, | |
| check=False, | |
| cwd=cwd, | |
| capture_output=True, | |
| text=True, | |
| ) | |
| if check and result.returncode != 0: | |
| raise GitError(f"Command failed: {' '.join(args)}\n{result.stderr}") | |
| return result.returncode, result.stdout.strip(), result.stderr.strip() | |
def is_git_repo() -> bool:
    """Return True when the current directory is inside a git work tree.

    ``git rev-parse --is-inside-work-tree`` exits 0 but prints ``false`` inside
    a ``.git`` directory or a bare repository, so the printed value is checked
    in addition to the exit status.
    """
    try:
        returncode, stdout, _ = run_command_sync(
            "git", "rev-parse", "--is-inside-work-tree", check=False
        )
    except (GitError, OSError):
        # git could not be launched at all
        return False
    return returncode == 0 and stdout == "true"
def get_current_branch() -> str:
    """Return the checked-out branch name, or the full HEAD SHA when detached.

    In the detached case a warning with the abbreviated SHA is printed.

    Raises:
        GitError: If either git command fails.
    """
    _, branch_name, _ = run_command_sync("git", "branch", "--show-current")
    if not branch_name:
        _, head_sha, _ = run_command_sync("git", "rev-parse", "HEAD")
        console.print(f"[yellow]β Warning: In detached HEAD state ({head_sha[:7]})[/yellow]")
        return head_sha
    return branch_name
| def get_worktree_path(branch: str) -> Path | None: | |
| """Get worktree path for a branch if it exists.""" | |
| try: | |
| _, stdout, _ = run_command_sync("git", "worktree", "list", "--porcelain") | |
| worktree_path = None | |
| for line in stdout.split("\n"): | |
| if line.startswith("worktree "): | |
| worktree_path = Path(line.split(" ", 1)[1]) | |
| elif line.startswith("branch ") and line.endswith(f"refs/heads/{branch}"): | |
| return worktree_path | |
| except GitError: | |
| pass | |
| return None | |
async def get_default_branch() -> str:
    """Return the repository's default branch name as reported by ``gh repo view``.

    Raises:
        GitError: If gh fails or prints an empty name.
    """
    jq_filter = ".defaultBranchRef.name"
    _, branch_name, _ = await run_command(
        "gh", "repo", "view", "--json", "defaultBranchRef", "-q", jq_filter
    )
    if branch_name:
        return branch_name
    raise GitError("Failed to detect default branch")
async def get_my_prs(base_branch: str, limit: int = 1000) -> list[str]:
    """Return head branch names of the current user's open PRs targeting *base_branch*.

    Args:
        base_branch: Base branch the PRs must target.
        limit: Maximum number of PRs to request. ``gh pr list`` fetches only
            30 items unless ``--limit`` is given, which silently dropped PRs.

    Returns:
        Head branch names in the order gh lists them; empty when none are found.

    Raises:
        GitError: If the ``gh pr list`` call fails.
    """
    panel = Panel(
        f"[cyan]Default branch: [bold]{base_branch}[/bold][/cyan]",
        title="[blue]π Collecting PRs[/blue]",
        border_style="blue",
    )
    console.print(panel)
    _, stdout, _ = await run_command(
        "gh",
        "pr",
        "list",
        "--author",
        "@me",
        "--base",
        base_branch,
        "--state",
        "open",
        "--limit",
        str(limit),
        "--json",
        "headRefName",
        "-q",
        ".[].headRefName",
    )
    # One branch per line; skip blank lines so "no output" never yields [""].
    branches = [line.strip() for line in stdout.splitlines() if line.strip()]
    if not branches:
        console.print("[yellow]β No PRs found[/yellow]")
        return []
    console.print(f"[green]β Found {len(branches)} PR(s)[/green]\n")
    return branches
async def get_pr_info(branch: str) -> PRInfo | None:
    """Look up the open PR whose head branch is *branch*.

    Returns:
        A PRInfo for the first PR gh reports, or None when there is no such PR
        or the gh call / JSON decoding fails.
    """
    requested_fields = "number,title,url,baseRefName"
    try:
        _, raw_json, _ = await run_command(
            "gh", "pr", "list",
            "--head", branch,
            "--state", "open",
            "--json", requested_fields,
            "--limit", "1",
            check=False,
        )
        if not raw_json or raw_json == "[]":
            return None
        entries = json.loads(raw_json)
        if not entries:
            return None
        first = entries[0]
        return PRInfo(
            number=first["number"],
            title=first["title"],
            url=first["url"],
            base_branch=first["baseRefName"],
        )
    except (GitError, json.JSONDecodeError, KeyError, IndexError):
        return None
async def check_rebase_needed(branch: str, base_branch: str) -> tuple[bool, int]:
    """Report whether ``origin/<branch>`` is missing commits from ``origin/<base_branch>``.

    Counts the commits reachable from the base but not from the branch.

    Returns:
        ``(needs_rebase, commits_behind)``. When the count cannot be obtained
        it errs on the side of rebasing and returns ``(True, 0)``.
    """
    revision_range = f"origin/{branch}..origin/{base_branch}"
    try:
        _, count_text, _ = await run_command(
            "git", "rev-list", "--count", revision_range, check=False
        )
        if count_text:
            behind = int(count_text)
            return behind > 0, behind
    except (GitError, ValueError):
        pass
    return True, 0
| async def find_closest_branch(branch: str, default_branch: str, candidate_branches: set[str]) -> str | None: | |
| """ | |
| Find the closest branch using merge-base heuristic. | |
| Returns the branch name with minimum commit distance, or None if default is closest. | |
| """ | |
| if not candidate_branches: | |
| return None | |
| distances: dict[str, int] = {} | |
| for candidate in candidate_branches: | |
| if candidate == branch: | |
| continue | |
| try: | |
| _, stdout, _ = await run_command( | |
| "git", | |
| "rev-list", | |
| "--count", | |
| f"origin/{branch}", | |
| f"^origin/{candidate}", | |
| check=False, | |
| ) | |
| if stdout: | |
| distances[candidate] = int(stdout) | |
| except (GitError, ValueError): | |
| continue | |
| if not distances: | |
| return None | |
| closest = min(distances, key=distances.get) | |
| closest_distance = distances[closest] | |
| default_distance = distances.get(default_branch, float("inf")) | |
| if closest_distance < default_distance and closest != default_branch: | |
| return closest | |
| return None | |
| def extract_branch_from_pr_url(url: str) -> str | None: | |
| """Extract branch name from GitHub PR URL.""" | |
| match = PR_URL_PATTERN.match(url) | |
| if not match: | |
| return None | |
| pr_number = match.group(1) | |
| console.print(f"[blue]π Extracting branch from PR #{pr_number}...[/blue]") | |
| try: | |
| _, stdout, _ = run_command_sync("gh", "pr", "view", url, "--json", "headRefName", "-q", ".headRefName") | |
| return stdout if stdout else None | |
| except GitError: | |
| console.print("[red]β Error: Failed to extract branch from PR[/red]") | |
| return None | |
async def has_uncommitted_changes() -> bool:
    """True when ``git status --porcelain`` succeeds and lists at least one entry."""
    status_code, porcelain, _ = await run_command("git", "status", "--porcelain", check=False)
    if status_code != 0:
        return False
    return porcelain.strip() != ""
async def stash_changes() -> bool:
    """Stash uncommitted changes, including untracked files (``git stash push -u``).

    Returns:
        True only if a new stash entry was actually created, so the caller can
        safely pop it later; False when the tree is clean or nothing was stashed.
    """
    if not await has_uncommitted_changes():
        return False
    console.print("[cyan]πΎ Stashing uncommitted changes...[/cyan]")
    # Compare the stash tip before and after: `git stash push` can exit 0 without
    # creating an entry (e.g. "No local changes to save"), and reporting True then
    # would make the final pop restore an unrelated, older stash.
    _, stash_before, _ = await run_command(
        "git", "rev-parse", "-q", "--verify", "refs/stash", check=False
    )
    returncode, _, _ = await run_command(
        "git", "stash", "push", "-u", "-m", "updatebranch auto-stash", check=False
    )
    _, stash_after, _ = await run_command(
        "git", "rev-parse", "-q", "--verify", "refs/stash", check=False
    )
    if returncode == 0 and stash_after and stash_after != stash_before:
        console.print("[green]β Changes stashed[/green]")
        return True
    console.print("[yellow]β Failed to stash changes[/yellow]")
    return False
async def pop_stash() -> bool:
    """Apply and drop the newest stash entry with ``git stash pop``.

    Returns:
        True on success, and also when there was no stash entry to pop;
        otherwise prints git's error output and returns False.
    """
    console.print("[cyan]π€ Restoring stashed changes...[/cyan]")
    code, _, err_text = await run_command("git", "stash", "pop", check=False)
    if code == 0:
        console.print("[green]β Changes restored[/green]")
        return True
    nothing_to_pop = "No stash entries found" in err_text
    if not nothing_to_pop:
        console.print("[yellow]β Failed to restore stashed changes[/yellow]")
        console.print(f"[dim]{err_text}[/dim]")
    return nothing_to_pop
async def checkout_branch(branch: str, original_dir: Path) -> tuple[bool, Path | None]:
    """Make *branch* the active branch, moving into its worktree if necessary.

    First tries ``git checkout`` in the current directory. If that fails and
    the branch is checked out in another worktree, the process ``chdir``s
    into that worktree instead. *original_dir* is not used here.

    Returns:
        ``(success, worktree_path)``. ``worktree_path`` is set only when the
        process moved into another worktree, and the caller must chdir back.
    """
    checkout_code, _, _ = await run_command("git", "checkout", branch, check=False)
    if checkout_code == 0:
        console.print("\t[green]β Checked out in current location[/green]")
        return True, None

    location = get_worktree_path(branch)
    if location is None:
        console.print("\t[red]β Checkout failed[/red]")
        return False, None

    console.print(
        f"\t[cyan]π Branch in worktree: {location}\n"
        f"\t Moving to worktree...[/cyan]"
    )
    try:
        os.chdir(location)
        console.print("\t[green]β Now in worktree[/green]")
    except OSError:
        console.print("\t[red]β Cannot access worktree[/red]")
        return False, None
    return True, location
async def rebase_branch(branch: str, base_branch: str) -> tuple[bool, str]:
    """Rebase the current checkout onto *base_branch* via ``git pull --rebase origin``.

    On failure ``git rebase --abort`` is attempted (best effort). *branch* is
    not used; whatever is currently checked out is rebased.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, <git stderr>)``.
    """
    code, _, err_text = await run_command(
        "git", "pull", "--rebase", "origin", base_branch, check=False
    )
    if code == 0:
        console.print(f"\t[green]β Rebased on origin/{base_branch}[/green]")
        return True, ""
    console.print(f"\t[red]β Rebase failed on origin/{base_branch}[/red]")
    if err_text:
        console.print(f"\t[dim red]{err_text}[/dim red]")
    await run_command("git", "rebase", "--abort", check=False)
    return False, err_text
async def force_push() -> bool:
    """Push the current branch with ``git push --force-with-lease``.

    Returns:
        True on success. On failure git's stderr is printed, as ``rebase_branch``
        does for rebase errors, so rejected leases, missing upstreams or auth
        problems can be diagnosed.
    """
    returncode, _, stderr = await run_command("git", "push", "--force-with-lease", check=False)
    if returncode != 0:
        console.print("\t[red]β Force push failed[/red]")
        if stderr:
            # Plain Text (not markup) so output like "! [rejected] ..." is shown verbatim.
            console.print(Text(f"\t{stderr}", style="dim red"))
        return False
    console.print("\t[green]β Force pushed with lease[/green]")
    return True
async def process_branch(branch_info: BranchInfo, original_dir: Path) -> tuple[bool, str]:
    """Check out, rebase onto its base branch, and force-push one branch.

    If checkout moved the process into another worktree, the working
    directory is reset to *original_dir* before returning.

    Returns:
        ``(success, error_message)``; the message is empty on success.
    """
    checked_out, moved_to = await checkout_branch(branch_info.name, original_dir)
    if not checked_out:
        return False, "Checkout failed"
    try:
        rebased, rebase_error = await rebase_branch(branch_info.name, branch_info.base_branch)
        if not rebased:
            return False, rebase_error
        pushed = await force_push()
        return (True, "") if pushed else (False, "Force push failed")
    finally:
        if moved_to:
            os.chdir(original_dir)
async def collect_branch_info(branches: list[str], default_base_branch: str) -> list[BranchInfo]:
    """Build one BranchInfo per branch, in the same order as *branches*.

    PR lookups and rebase checks run concurrently via ``asyncio.gather``.
    A branch without an open PR gets its base from ``find_closest_branch``
    (one branch at a time), falling back to *default_base_branch*.
    """
    with console.status("[cyan]π Fetching branch information...[/cyan]", spinner="dots"):
        pr_infos = await asyncio.gather(*(get_pr_info(name) for name in branches))

        known_branches = set(branches)
        bases: list[str] = []
        for name, pr in zip(branches, pr_infos):
            if pr:
                bases.append(pr.base_branch)
                continue
            guessed = await find_closest_branch(name, default_base_branch, known_branches)
            bases.append(guessed or default_base_branch)
            if guessed and guessed != default_base_branch:
                console.print(
                    f"[dim] βΉ {name}: No PR found, detected base β {guessed}[/dim]"
                )

        statuses = await asyncio.gather(
            *(check_rebase_needed(name, base) for name, base in zip(branches, bases))
        )

    infos: list[BranchInfo] = []
    for name, pr, base, (needs, behind) in zip(branches, pr_infos, bases, statuses):
        infos.append(
            BranchInfo(
                name=name,
                pr=pr,
                base_branch=base,
                needs_rebase=needs,
                commits_behind=behind,
            )
        )
    return infos
def sort_branches_by_dependency(branch_infos: list[BranchInfo]) -> list[BranchInfo]:
    """Order branches so each one comes after the branch it is based on.

    If branch B's base is branch A and both are in *branch_infos*, A is placed
    before B. Bases outside *branch_infos* impose no ordering. On a dependency
    cycle a warning is printed and the input list is returned unchanged.
    """
    by_name = {info.name: info for info in branch_infos}
    graph: dict[str, set[str]] = {
        info.name: {info.base_branch} if info.base_branch in by_name else set()
        for info in branch_infos
    }
    try:
        ordered_names = list(TopologicalSorter(graph).static_order())
    except CycleError as err:
        # Only a real cycle is expected here; other errors should surface.
        console.print("[yellow]β Warning: Circular dependency detected, using original order[/yellow]")
        console.print(f"[dim]{err}[/dim]")
        return branch_infos
    return [by_name[name] for name in ordered_names]
def format_rebase_status(needs_rebase: bool, commits_behind: int) -> str:
    """Return the Rich-markup text for the Status column of the branch table.

    Internal helper for ``display_branches``. A branch flagged for rebase
    without a positive commit count (e.g. the count could not be determined)
    is labelled "needs rebase" rather than a contradictory "0 behind".
    """
    if not needs_rebase:
        return "[green]β Up to date[/green]"
    if commits_behind > 0:
        return f"[yellow]{commits_behind} behind[/yellow]"
    return "[yellow]needs rebase[/yellow]"


def display_branches(branch_infos: list[BranchInfo], default_branch: str) -> None:
    """Print a table of branches with their base, rebase status and PR details.

    Bases other than *default_branch* (e.g. stacked branches) are highlighted.
    Branch names and PR titles are rendered as plain Text, so brackets in them
    (e.g. "[feat] ...") are not interpreted as Rich markup.
    """
    table = Table(
        title="[bold cyan]π Branches to Process[/bold cyan]",
        show_header=True,
        header_style="bold magenta",
        border_style="cyan",
        title_style="bold cyan",
    )
    table.add_column("Branch", style="blue", no_wrap=False)
    table.add_column("Base", style="cyan", no_wrap=False)
    table.add_column("Status", justify="center", style="white")
    table.add_column("PR #", justify="right", style="yellow")
    table.add_column("Title", style="white")
    table.add_column("URL", style="dim cyan", no_wrap=False)
    for branch_info in branch_infos:
        status = format_rebase_status(branch_info.needs_rebase, branch_info.commits_behind)
        if branch_info.base_branch == default_branch:
            base_display = f"[dim cyan]{branch_info.base_branch}[/dim cyan]"
        else:
            base_display = f"[bold magenta]{branch_info.base_branch}[/bold magenta]"
        if branch_info.pr:
            pr_cells = (
                f"#{branch_info.pr.number}",
                Text(branch_info.pr.title),
                branch_info.pr.url,
            )
        else:
            pr_cells = ("[dim]N/A[/dim]", "[yellow]No PR found[/yellow]", "[dim]N/A[/dim]")
        table.add_row(Text(branch_info.name), base_display, status, *pr_cells)
    console.print(table)
async def _resolve_default_branch() -> str:
    """Return the repository's default branch, or "main" when gh cannot report it."""
    try:
        return await get_default_branch()
    except (GitError, OSError):
        console.print("[yellow]Warning: could not detect the default branch via gh; assuming 'main'[/yellow]")
        return "main"


def _resolve_branch_arguments(items: list[str]) -> list[str]:
    """Map CLI items (branch names or GitHub PR URLs) to a sorted, de-duplicated branch list.

    PR URLs that cannot be resolved are dropped.
    """
    resolved: set[str] = set()
    for item in items:
        if item.startswith("https://github.com/"):
            branch = extract_branch_from_pr_url(item)
            if branch:
                resolved.add(branch)
        else:
            resolved.add(item)
    return sorted(resolved)


def _print_usage() -> None:
    """Print the usage panel shown when no branches could be determined."""
    panel = Panel(
        # "\\[" makes Rich print "[branch1|PR_URL1]" literally instead of parsing it
        # as a markup tag.
        "[bold]Usage:[/bold] updatebranch [OPTIONS] \\[branch1|PR_URL1] ...\n\n"
        "[bold cyan]Options:[/bold cyan]\n"
        "\t[yellow]--all[/yellow] Process all your PRs\n"
        "\t[yellow]-y, --yes[/yellow] Skip confirmation\n\n"
        "[bold cyan]Examples:[/bold cyan]\n"
        "\t[dim]updatebranch --all[/dim]\n"
        "\t[dim]updatebranch --all -y[/dim]\n"
        "\t[dim]updatebranch feature/fix-bug[/dim]\n"
        "\t[dim]updatebranch -y https://github.com/owner/repo/pull/123[/dim]",
        title="[red]β No branches specified[/red]",
        border_style="red",
    )
    console.print(panel)


async def _fetch_extra_bases(branch_infos: list[BranchInfo], default_branch: str) -> None:
    """Fetch non-default base branches and recompute rebase status against them.

    ``collect_branch_info`` runs before these refs are fetched, so its status
    for branches on a non-default base may reflect a stale ``origin/<base>``.
    """
    extra_bases = sorted({info.base_branch for info in branch_infos} - {default_branch})
    if not extra_bases:
        return
    console.print("[cyan]π Fetching additional base branches...[/cyan]")
    results = await asyncio.gather(
        *(run_command("git", "fetch", "origin", base, check=False) for base in extra_bases)
    )
    failed_fetches = [base for base, (code, _, _) in zip(extra_bases, results) if code != 0]
    for base in failed_fetches:
        console.print(f"[yellow]Warning: could not fetch origin/{base}[/yellow]")
    if not failed_fetches:
        console.print("[green]β All base branches fetched[/green]\n")
    affected = [info for info in branch_infos if info.base_branch != default_branch]
    statuses = await asyncio.gather(
        *(check_rebase_needed(info.name, info.base_branch) for info in affected)
    )
    for info, (needs_rebase, commits_behind) in zip(affected, statuses):
        info.needs_rebase = needs_rebase
        info.commits_behind = commits_behind


def _mark_stacked_dependents(branch_infos: list[BranchInfo]) -> None:
    """Flag every branch stacked (directly or transitively) on a branch that will be rebased.

    Rebasing and force-pushing a base rewrites it, so a dependent that looked
    up to date against the old base tip must be rebased as well.
    """
    queued = {info.name for info in branch_infos if info.needs_rebase}
    changed = True
    while changed:
        changed = False
        for info in branch_infos:
            if not info.needs_rebase and info.base_branch in queued:
                info.needs_rebase = True
                queued.add(info.name)
                changed = True


async def _restore_original_state(original_dir: Path, original_branch: str, stash_created: bool) -> None:
    """Return to the starting directory and branch, then restore the auto-stash.

    The stash is popped only if the original branch was checked out again, so
    the user's changes are never applied on top of a different branch.
    """
    cleanup = Panel(
        f"[cyan]Returning to:[/cyan] [bold]{original_branch}[/bold]",
        title="[blue]π§Ή Cleanup[/blue]",
        border_style="blue",
    )
    console.print(cleanup)
    os.chdir(original_dir)
    try:
        returncode, _, _ = run_command_sync("git", "checkout", original_branch, check=False)
    except (GitError, OSError):
        returncode = 1
    if returncode != 0:
        console.print("[yellow]β Warning: Could not return to original branch[/yellow]")
        if stash_created:
            console.print("[yellow]Stashed changes were kept; restore them with `git stash pop`.[/yellow]")
        return
    if stash_created:
        console.print()
        await pop_stash()


@app.command()
def main(
    branches: list[str] = typer.Argument(
        None,
        help="Branch names or PR URLs to process",
    ),
    all_prs: bool = typer.Option(
        False,
        "--all",
        help="Process all your PRs based on the default branch",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompt",
    ),
) -> None:
    """
    Rebase multiple branches on their base branches and force push.
    Supports worktree-checked-out branches and automatically handles them.
    """
    # Flow: stash local changes -> resolve branches -> fetch bases -> collect
    # PR/base/rebase info -> confirm -> rebase + force-push in dependency order
    # -> return to the starting branch and restore the stash.
    # Exits with status 1 if any branch fails or a git/gh command errors.
    if not is_git_repo():
        console.print("[red]β Error: Not in a git repository[/red]")
        raise typer.Exit(1)

    original_dir = Path.cwd()
    original_branch = get_current_branch()

    async def async_main() -> None:
        stash_created = False
        try:
            stash_created = await stash_changes()
            if stash_created:
                console.print()

            requested = list(branches) if branches else []
            if not requested and not all_prs:
                _print_usage()
                raise typer.Exit(1)

            # Detect the real default branch even without --all, so repositories
            # whose default is not "main" do not fail on `git fetch origin main`.
            default_branch = await _resolve_default_branch()
            if all_prs:
                requested.extend(await get_my_prs(default_branch))

            unique_branches = _resolve_branch_arguments(requested)
            if not unique_branches:
                _print_usage()
                raise typer.Exit(1)

            fetch_panel = Panel(
                f"[cyan]Fetching latest [bold]origin/{default_branch}[/bold]...[/cyan]",
                title="[blue]π Updating[/blue]",
                border_style="blue",
            )
            console.print(fetch_panel)
            await run_command("git", "fetch", "origin", default_branch)
            console.print("[green]β Fetch completed[/green]\n")

            branch_infos = await collect_branch_info(unique_branches, default_branch)
            await _fetch_extra_bases(branch_infos, default_branch)
            _mark_stacked_dependents(branch_infos)
            display_branches(branch_infos, default_branch)

            branches_to_process = [info for info in branch_infos if info.needs_rebase]
            skipped_count = len(branch_infos) - len(branches_to_process)
            if skipped_count > 0:
                console.print(f"\n[dim]βΉ Skipping {skipped_count} branch(es) already up to date[/dim]")
            if not branches_to_process:
                console.print("\n[green]β All branches are up to date, nothing to do[/green]")
                raise typer.Exit(0)

            branches_to_process = sort_branches_by_dependency(branches_to_process)
            total = len(branches_to_process)
            console.print(f"[dim]π Processing order: {' β '.join(info.name for info in branches_to_process)}[/dim]\n")

            if not yes:
                console.print()
                warning = Panel(
                    f"[bold yellow]β This will rebase {total} branch(es) on their respective base branches and force push.[/bold yellow]",
                    border_style="yellow",
                )
                console.print(warning)
                if not Confirm.ask("[yellow]Continue?[/yellow]", default=True):
                    console.print("[red]β Aborted by user[/red]")
                    raise typer.Exit(0)
            else:
                console.print("\n[green]β --yes flag detected, skipping confirmation[/green]\n")

            failed: list[str] = []
            with Progress(console=console) as progress:
                task: TaskID = progress.add_task("[cyan]Processing branches...", total=total)
                for idx, branch_info in enumerate(branches_to_process, 1):
                    header = Panel(
                        f"[bold]{branch_info.name}[/bold]",
                        title=f"[cyan]π§ [{idx}/{total}][/cyan]",
                        border_style="green",
                    )
                    console.print(header)
                    success, error_msg = await process_branch(branch_info, original_dir)
                    result_text = Text()
                    if success:
                        result_text.append("β Successfully processed ", style="green bold")
                        result_text.append(branch_info.name, style="green")
                    else:
                        failed.append(branch_info.name)
                        result_text.append("β Failed to process ", style="red bold")
                        result_text.append(branch_info.name, style="red")
                        if error_msg:
                            result_text.append(f"\n\t{error_msg}", style="dim red")
                    console.print(result_text)
                    console.print()
                    progress.update(
                        task,
                        advance=1,
                        description=f"[cyan]Processing branches... ({idx}/{total})[/cyan]",
                    )

            # Report successes and failures separately instead of counting every
            # attempted branch as "processed".
            summary_lines = [f"[bold green]β Processed {total - len(failed)} branch(es)[/bold green]"]
            if failed:
                summary_lines.append(
                    f"[bold red]Failed {len(failed)} branch(es): {', '.join(failed)}[/bold red]"
                )
            if skipped_count > 0:
                summary_lines.append(f"[dim]Skipped {skipped_count} already up to date[/dim]")
            summary = Panel(
                "\n".join(summary_lines),
                title="[yellow]Completed with failures[/yellow]" if failed else "[green]β Complete[/green]",
                border_style="yellow" if failed else "green",
            )
            console.print(summary)
            if failed:
                raise typer.Exit(1)
        except GitError as err:
            # e.g. the initial fetch failed: show a clean message instead of a traceback.
            console.print(Text(f"Error: {err}", style="red"))
            raise typer.Exit(1) from err
        finally:
            await _restore_original_state(original_dir, original_branch, stash_created)

    asyncio.run(async_main())
# Script entry point: Typer parses the command line and dispatches to `main`.
if __name__ == "__main__":
    app()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment