bugfixes, toposort, better logging
mumbleskates committed Jun 27, 2024
1 parent e05bd20 commit 7242656
Showing 1 changed file with 105 additions and 41 deletions.
146 changes: 105 additions & 41 deletions unsquash.py
@@ -2,25 +2,25 @@

import argparse
from base64 import b64decode
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from getpass import getpass
from itertools import count
from itertools import chain, count
import json
import logging
import re
import sqlite3
import sys
import time
from typing import Callable, Generator, Optional
from typing import Callable, Generator, Iterable, Optional
from urllib.parse import urlparse

from dulwich.client import (get_credentials_from_store, get_transport_and_path,
from dulwich.client import (DEFAULT_GIT_CREDENTIALS_PATHS,
get_credentials_from_store, get_transport_and_path,
GitClient, Urllib3HttpGitClient)
from dulwich.objects import Blob, Commit, Tree
from dulwich.repo import Repo
from github import Github, PullRequest, RateLimitExceededException
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

__doc__ = """
-- GitHub Unsquasher --
@@ -45,8 +45,6 @@

INFINITE_PAST = datetime.min.replace(tzinfo=timezone.utc)

log = logging.getLogger(__name__)


def string_to_datetime(s: str) -> datetime:
return datetime_as_utc(datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ"))
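# For illustration (not part of this commit): the format string above parses
# ISO-8601 UTC timestamps as emitted by the GitHub API, e.g.
#   string_to_datetime("2024-06-27T01:23:45Z")
# which, assuming datetime_as_utc attaches the UTC timezone, yields
#   datetime(2024, 6, 27, 1, 23, 45, tzinfo=timezone.utc)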
@@ -142,12 +140,6 @@ def main():
if args.squashed_branch is None and args.squashed_ref is None:
args.squashed_branch = "master" # default to master branch

if args.fetch_squashed_branch:
if args.squashed_branch is None:
print("--squashed_ref is not compatible with "
"--fetch_squashed_branch.")
sys.exit(1)

if args.git_remote_url is None:
if args.squashed_remote is None:
try:
@@ -166,15 +158,20 @@

parsed_url = urlparse(args.git_remote_url)
if parsed_url.scheme in ("http", "https"):
username, password = get_credentials_from_store(
parsed_url.scheme,
parsed_url.hostname,
creds = get_credentials_from_store(
parsed_url.scheme.encode(),
parsed_url.hostname.encode(),
fnames=(
None if args.git_credentials_file is None else [
args.git_credentials_file
]
[args.git_credentials_file]
if args.git_credentials_file is not None else
DEFAULT_GIT_CREDENTIALS_PATHS
),
)
username, password = creds or (None, None)
if isinstance(username, bytes):
username = username.decode()
if isinstance(password, bytes):
password = password.decode()
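# Illustrative sketch (not part of this commit): the bytes-vs-str handling
# above could also be written as a small helper, assuming creds is the
# optional (username, password) pair returned by get_credentials_from_store:
#
#   def _decode_opt(value):
#       return value.decode() if isinstance(value, bytes) else value
#
#   username, password = (_decode_opt(v) for v in (creds or (None, None)))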
remote_git_client = Urllib3HttpGitClient(
args.git_remote_url,
username=username,
@@ -223,9 +220,9 @@ def main():
with open(args.token_file, 'r') as f:
token = f.read().strip()

with logging_redirect_tqdm(), GithubCache(db_path=args.pr_db,
github_repo_name=args.github_repo,
github_token=token) as gh_db:
with GithubCache(db_path=args.pr_db,
github_repo_name=args.github_repo,
github_token=token) as gh_db:
rebuild_history(repo=repo, gh_db=gh_db,
remote=remote_git_client,
remote_path=path,
@@ -639,17 +636,19 @@ def rebuild_history(repo: Repo, remote: GitClient, remote_path: str,

# mapping of {squashed commit id: unsquashed commit id}
unsquashed_mapping = map_unsquashed(repo=repo, heads=map_heads)
# commit_stack will hold all the commits to be unsquashed, the ones with few
# to no unprocessed ancestors at the end to be popped off first.
commit_stack = []
# pending_commits will hold the same values as commit_stack.
# map from squashed commit to merge tip
certain_pr_merges: dict[bytes, bytes] = {}
# pending_commits will hold all the commits to be unsquashed.
pending_commits = set()
# already_processed_tips holds the set of squashed commits we neither need
# to fetch nor process again, as they are already unsquashed or queued up.
already_processed_tips = []
for commit in map_heads:
already_processed_tips.append(commit)
already_processed_tips.append(detect_original_commit(commit))
for commit_id in map_heads:
already_processed_tips.append(commit_id)
commit = repo[commit_id]
original_commit_id = detect_original_commit(commit)
if original_commit_id is not None:
already_processed_tips.append(original_commit_id)

# tips that are being processed this iteration
new_tips = [squashed_head]
@@ -665,7 +664,6 @@
walk.commit.id in unsquashed_mapping
):
continue # TODO(widders): this shouldn't happen i think?
commit_stack.append(walk.commit.id)
pending_commits.add(walk.commit.id)
# Check if this commit is a squashed pull request
if len(walk.commit.parents) >= 2:
@@ -684,6 +682,7 @@
merge_tip not in unsquashed_mapping
):
# we will try to fetch the pr's contents from the remote
certain_pr_merges[walk.commit.id] = merge_tip
if merge_tip not in new_squash_commits:
new_squash_commits.add(merge_tip)
if merge_tip not in repo:
@@ -693,8 +692,8 @@
already_processed_tips.extend(new_tips)

if tips_to_fetch:
log.info(f"attempting to fetch {len(tips_to_fetch)} out of "
f"{len(new_squash_commits)} missing squashed refs")
print(f"attempting to fetch {len(tips_to_fetch)} out of "
f"{len(new_squash_commits)} missing squashed refs")

def determine_wants(sha_dict: dict[bytes, bytes],
_depth: Optional[int] = None) -> list[bytes]:
@@ -704,28 +703,93 @@ def determine_wants(sha_dict: dict[bytes, bytes],
if commit in tips_to_fetch
]

# if we are on a tty, erase the line and reset the cursor each time.
# dulwich emits "\r" at the end of most lines, which is enough to
# reset the cursor to the start but not enough to clean up any extra
# characters if the progress line gets shorter.
progress_end = "\x1b[1K\r" if sys.stdout.isatty() else "\n"

def progress(msg: bytes):
print(msg.decode().strip(), end=progress_end)

# fetch as many of those tips as possible from the repo
fetch_result = remote.fetch(remote_path, repo,
determine_wants=determine_wants,
progress=log.info)
progress=progress)
# save the refs we wanted to keep
for ref, commit in fetch_result.refs.items():
if (
commit in tips_to_fetch and
commit in repo and
ref not in repo.refs
):
repo.refs[ref] = commit
failed_to_fetch = 0
not_stomped = 0
with tqdm(desc="saving fetched refs", unit="ref",
total=len(tips_to_fetch)) as save_ref_bar:
for ref, commit in fetch_result.refs.items():
if commit not in tips_to_fetch:
continue
save_ref_bar.update(1)
if commit not in repo:
failed_to_fetch += 1
elif ref in repo.refs:
not_stomped += 1
else:
repo.refs[ref] = commit
print(f"{failed_to_fetch} refs failed to fetch")
if not_stomped:
print(f"{not_stomped} refs not saved to avoid stomping "
f"existing refs")

new_tips = [
commit
for commit in new_squash_commits
if commit in repo
]
log.info(f"got {len(new_tips)} new refs")
print(f"got {len(new_tips)} new refs")
if not new_tips:
break # nothing more to do

def unsquashed_parents_of_commit(commit_id: bytes) -> Iterable[bytes]:
merge_tip = certain_pr_merges.get(commit_id)
commit = repo[commit_id]
return chain(commit.parents, (merge_tip,) if merge_tip else ())
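# Illustrative example (not part of this commit): for a squash commit S whose
# stored parents are [P] and whose detected PR merge tip is T,
# unsquashed_parents_of_commit(S) yields P and then T; for a commit with no
# recorded merge tip it simply yields the stored parents unchanged.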

# Start performing a topological sort of all the new squashed commits we
# currently have in the repo (so, excluding the ones we failed to fetch;
# those will be fetched via the REST api later, during unsquashing)
unsquashed_commits = set(unsquashed_mapping.values())
# mapping of {commit id: number of child commits not yet in the repo}
dependent = defaultdict(int)
for commit_id in tqdm(pending_commits, desc="preprocessing commit graph",
unit="commit"):
for parent_id in unsquashed_parents_of_commit(commit_id):
if (
parent_id in unsquashed_mapping or
parent_id in unsquashed_commits
):
continue
dependent[parent_id] += 1

# list of pending commits whose parents are all in the repo or already in
# the commit stack
ready = [
commit_id for commit_id in pending_commits
if commit_id not in dependent
]

commit_stack = []
with tqdm(desc="building commit queue", unit="commit",
total=len(pending_commits)) as build_bar:
while ready:
commit_id = ready.pop()
commit_stack.append(commit_id)
for parent_id in unsquashed_parents_of_commit(commit_id):
current_children = dependent[parent_id]
if current_children == 1:
del dependent[parent_id]
ready.append(parent_id)
else:
dependent[parent_id] = current_children - 1
build_bar.update(1)

assert len(dependent) == 0, "DAG violation: topological sort failed"
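# Illustrative standalone sketch (not part of this commit): the loop above is
# a Kahn-style topological sort run from the tips toward the roots, so commits
# near the roots land at the end of commit_stack and are popped off first.
# On a toy graph where parents(x) lists the parents of commit x:
#
#   parents = {"d": ["b", "c"], "b": ["a"], "c": ["a"], "a": []}
#   dependent = {}
#   for child, ps in parents.items():
#       for p in ps:
#           dependent[p] = dependent.get(p, 0) + 1
#   ready = [c for c in parents if c not in dependent]  # the tips, here ["d"]
#   order = []
#   while ready:
#       c = ready.pop()
#       order.append(c)
#       for p in parents[c]:
#           if dependent[p] == 1:
#               del dependent[p]
#               ready.append(p)
#           else:
#               dependent[p] -= 1
#   # order is ["d", "c", "b", "a"]; popping from the end visits "a" (a root)
#   # before "d" (the tip), mirroring how commit_stack is consumed during
#   # unsquashing.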

head_commit_id = None
rewrite_progress = tqdm(total=len(commit_stack),
desc="unsquashing ", unit="commit")
