diff --git a/unsquash.py b/unsquash.py index 6bd1612..a0b6331 100644 --- a/unsquash.py +++ b/unsquash.py @@ -2,25 +2,25 @@ import argparse from base64 import b64decode +from collections import defaultdict from datetime import datetime, timedelta, timezone from getpass import getpass -from itertools import count +from itertools import chain, count import json -import logging import re import sqlite3 import sys import time -from typing import Callable, Generator, Optional +from typing import Callable, Generator, Iterable, Optional from urllib.parse import urlparse -from dulwich.client import (get_credentials_from_store, get_transport_and_path, +from dulwich.client import (DEFAULT_GIT_CREDENTIALS_PATHS, + get_credentials_from_store, get_transport_and_path, GitClient, Urllib3HttpGitClient) from dulwich.objects import Blob, Commit, Tree from dulwich.repo import Repo from github import Github, PullRequest, RateLimitExceededException from tqdm import tqdm -from tqdm.contrib.logging import logging_redirect_tqdm __doc__ = """ -- GitHub Unsquasher -- @@ -45,8 +45,6 @@ INFINITE_PAST = datetime.min.replace(tzinfo=timezone.utc) -log = logging.getLogger(__name__) - def string_to_datetime(s: str) -> datetime: return datetime_as_utc(datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ")) @@ -142,12 +140,6 @@ def main(): if args.squashed_branch is None and args.squashed_ref is None: args.squashed_branch = "master" # default to master branch - if args.fetch_squashed_branch: - if args.squashed_branch is None: - print("--squashed_ref is not compatible with " - "--fetch_squashed_branch.") - sys.exit(1) - if args.git_remote_url is None: if args.squashed_remote is None: try: @@ -166,15 +158,20 @@ def main(): parsed_url = urlparse(args.git_remote_url) if parsed_url.scheme in ("http", "https"): - username, password = get_credentials_from_store( - parsed_url.scheme, - parsed_url.hostname, + creds = get_credentials_from_store( + parsed_url.scheme.encode(), + parsed_url.hostname.encode(), fnames=( - None if args.git_credentials_file is None else [ - args.git_credentials_file - ] + [args.git_credentials_file] + if args.git_credentials_file is not None else + DEFAULT_GIT_CREDENTIALS_PATHS ), ) + username, password = creds or (None, None) + if isinstance(username, bytes): + username = username.decode() + if isinstance(password, bytes): + password = password.decode() remote_git_client = Urllib3HttpGitClient( args.git_remote_url, username=username, @@ -223,9 +220,9 @@ def main(): with open(args.token_file, 'r') as f: token = f.read().strip() - with logging_redirect_tqdm(), GithubCache(db_path=args.pr_db, - github_repo_name=args.github_repo, - github_token=token) as gh_db: + with GithubCache(db_path=args.pr_db, + github_repo_name=args.github_repo, + github_token=token) as gh_db: rebuild_history(repo=repo, gh_db=gh_db, remote=remote_git_client, remote_path=path, @@ -639,17 +636,19 @@ def rebuild_history(repo: Repo, remote: GitClient, remote_path: str, # mapping of {squashed commit id: unsquashed commit id} unsquashed_mapping = map_unsquashed(repo=repo, heads=map_heads) - # commit_stack will hold all the commits to be unsquashed, the ones with few - # to no unprocessed ancestors at the end to be popped off first. - commit_stack = [] - # pending_commits will hold the same values as commit_stack. + # map from squashed commit to merge tip + certain_pr_merges: dict[bytes, bytes] = {} + # pending_commits will hold all the commits to be unsquashed. pending_commits = set() # already_processed_tips holds the set of squashed commits we neither need # to fetch nor process again, as they are already unsquashed or queued up. already_processed_tips = [] - for commit in map_heads: - already_processed_tips.append(commit) - already_processed_tips.append(detect_original_commit(commit)) + for commit_id in map_heads: + already_processed_tips.append(commit_id) + commit = repo[commit_id] + original_commit_id = detect_original_commit(commit) + if original_commit_id is not None: + already_processed_tips.append(original_commit_id) # tips that are being processed this iteration new_tips = [squashed_head] @@ -665,7 +664,6 @@ def rebuild_history(repo: Repo, remote: GitClient, remote_path: str, walk.commit.id in unsquashed_mapping ): continue # TODO(widders): this shouldn't happen i think? - commit_stack.append(walk.commit.id) pending_commits.add(walk.commit.id) # Check if this commit is a squashed pull request if len(walk.commit.parents) >= 2: @@ -684,6 +682,7 @@ def rebuild_history(repo: Repo, remote: GitClient, remote_path: str, merge_tip not in unsquashed_mapping ): # we will try to fetch the pr's contents from the remote + certain_pr_merges[walk.commit.id] = merge_tip if merge_tip not in new_squash_commits: new_squash_commits.add(merge_tip) if merge_tip not in repo: @@ -693,8 +692,8 @@ def rebuild_history(repo: Repo, remote: GitClient, remote_path: str, already_processed_tips.extend(new_tips) if tips_to_fetch: - log.info(f"attempting to fetch {len(tips_to_fetch)} out of " - f"{len(new_squash_commits)} missing squashed refs") + print(f"attempting to fetch {len(tips_to_fetch)} out of " + f"{len(new_squash_commits)} missing squashed refs") def determine_wants(sha_dict: dict[bytes, bytes], _depth: Optional[int] = None) -> list[bytes]: @@ -704,28 +703,93 @@ def determine_wants(sha_dict: dict[bytes, bytes], if commit in tips_to_fetch ] + # if we are on a tty, clean and reset erase the line each time. + # dulwich emits "\r" at the end of most lines which is enough to + # reset the cursor to the start but not enough to clean up any extra + # characters if the progress line gets shorter. + progress_end = "\x1b[1K\r" if sys.stdout.isatty() else "\n" + + def progress(msg: bytes): + print(msg.decode().strip(), end=progress_end) + # fetch as many of those tips as possible from the repo fetch_result = remote.fetch(remote_path, repo, determine_wants=determine_wants, - progress=log.info) + progress=progress) # save the refs we wanted to keep - for ref, commit in fetch_result.refs.items(): - if ( - commit in tips_to_fetch and - commit in repo and - ref not in repo.refs - ): - repo.refs[ref] = commit + failed_to_fetch = 0 + not_stomped = 0 + with tqdm(desc="saving fetched refs", unit="ref", + total=len(tips_to_fetch)) as save_ref_bar: + for ref, commit in fetch_result.refs.items(): + if commit not in tips_to_fetch: + continue + save_ref_bar.update(1) + if commit not in repo: + failed_to_fetch += 1 + elif ref in repo.refs: + not_stomped += 1 + else: + repo.refs[ref] = commit + print(f"{failed_to_fetch} refs failed to fetch") + if not_stomped: + print(f"{not_stomped} refs not saved to avoid stomping " + f"existing refs") new_tips = [ commit for commit in new_squash_commits if commit in repo ] - log.info(f"got {len(new_tips)} new refs") + print(f"got {len(new_tips)} new refs") if not new_tips: break # nothing more to do + def unsquashed_parents_of_commit(commit_id: bytes) -> Iterable[bytes]: + merge_tip = certain_pr_merges.get(commit_id) + commit = repo[commit_id] + return chain(commit.parents, (merge_tip,) if merge_tip else ()) + + # Start performing a topological sort of all the new squashed commits we + # currently have in the repo (so, excluding the ones we failed to fetch; + # those will be fetched via the REST api later, during unsquashing) + unsquashed_commits = set(unsquashed_mapping.values()) + # mapping of {commit id: number of child commits not yet in the repo} + dependent = defaultdict(int) + for commit_id in tqdm(pending_commits, desc="preprocessing commit graph", + unit="commit"): + for parent_id in unsquashed_parents_of_commit(commit_id): + if ( + parent_id in unsquashed_mapping or + parent_id in unsquashed_commits + ): + continue + dependent[parent_id] += 1 + + # list of pending commits whose parents are all in the repo or already in + # the commit stack + ready = [ + commit_id for commit_id in pending_commits + if commit_id not in dependent + ] + + commit_stack = [] + with tqdm(desc="building commit queue", unit="commit", + total=len(pending_commits)) as build_bar: + while ready: + commit_id = ready.pop() + commit_stack.append(commit_id) + for parent_id in unsquashed_parents_of_commit(commit_id): + current_children = dependent[parent_id] + if current_children == 1: + del dependent[parent_id] + ready.append(parent_id) + else: + dependent[parent_id] = current_children - 1 + build_bar.update(1) + + assert len(dependent) == 0, "DAG violation: topological sort failed" + head_commit_id = None rewrite_progress = tqdm(total=len(commit_stack), desc="unsquashing ", unit="commit")