Skip to content

Commit

Permalink
exp remove: rename --all-commits flag and unify the collection of r…
Browse files Browse the repository at this point in the history
…evs (iterative#7155)

fix: iterative#7155
1. rename flags `-A/--all-commits` in `exp remove`
2. add new flag `-n/--num` in `exp remove`
3. unify the revision collection in `exp remove`
4. add unit and func tests for `exp remove`

Co-authored-by: Jorge Orpinel <[email protected]>
  • Loading branch information
karajan1001 and jorgeorpinel committed Mar 9, 2022
1 parent 244637e commit efc021d
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 106 deletions.
31 changes: 23 additions & 8 deletions dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,44 @@

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import InvalidArgumentError

logger = logging.getLogger(__name__)


class CmdExperimentsRemove(CmdBase):
def raise_error_if_all_disabled(self):
    """Reject an invocation that selects nothing to remove.

    At least one of the experiment names, ``--all-commits``, ``--rev`` or
    ``--queue`` must be truthy on ``self.args``.

    Raises:
        InvalidArgumentError: if none of the selectors above was given.
    """
    selectors = (
        self.args.experiment,
        self.args.all_commits,
        self.args.rev,
        self.args.queue,
    )
    if not any(selectors):
        # `--queue` satisfies the check above, so the hint lists it too.
        raise InvalidArgumentError(
            "Either provide an `experiment` argument, or use the "
            "`--rev`, `--all-commits` or `--queue` flag."
        )

def run(self):

self.raise_error_if_all_disabled()

self.repo.experiments.remove(
exp_names=self.args.experiment,
all_commits=self.args.all_commits,
rev=self.args.rev,
num=self.args.num,
queue=self.args.queue,
clear_all=self.args.all,
remote=self.args.git_remote,
git_remote=self.args.git_remote,
)

return 0


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags

EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
Expand All @@ -31,15 +51,10 @@ def add_parser(experiments_subparsers, parent_parser):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
add_rev_selection_flags(experiments_remove_parser, "Remove", False)
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
remove_group.add_argument(
"-A",
"--all",
action="store_true",
help="Remove all committed experiments.",
)
remove_group.add_argument(
"-g",
"--git-remote",
Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/experiments/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List
from typing import Collection, Iterable

from dvc.exceptions import InvalidArgumentError

Expand All @@ -24,7 +24,7 @@ def __init__(

class UnresolvedExpNamesError(InvalidArgumentError):
def __init__(
self, unresolved_list: List[str], *args, git_remote: str = None
self, unresolved_list: Collection[str], *args, git_remote: str = None
):
unresolved_names = ";".join(unresolved_list)
if not git_remote:
Expand Down
218 changes: 141 additions & 77 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
import logging
from typing import List, Optional
from typing import (
TYPE_CHECKING,
Collection,
List,
Mapping,
Optional,
Set,
Union,
)

from dvc.exceptions import InvalidArgumentError
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm import RevError
from dvc.scm import iter_revs

from .base import ExpRefInfo
from .exceptions import UnresolvedExpNamesError
from .utils import (
exp_refs,
exp_refs_by_baseline,
push_refspec,
remove_exp_refs,
resolve_name,
)

if TYPE_CHECKING:
from dvc.scm import Git

from .utils import exp_refs, push_refspec, remove_exp_refs, resolve_name

logger = logging.getLogger(__name__)

Expand All @@ -15,101 +34,146 @@
@scm_context
def remove(
repo,
exp_names=None,
queue=False,
clear_all=False,
remote=None,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
exp_names: Union[None, str, List[str]] = None,
rev: Optional[str] = None,
all_commits: bool = False,
num: int = 1,
queue: bool = False,
git_remote: Optional[str] = None,
) -> int:
if not any([exp_names, queue, all_commits, rev]):
return 0

removed = 0
if queue:
removed += _clear_stash(repo)
if clear_all:
removed += _clear_all(repo)
if all_commits:
removed += _clear_all_commits(repo.scm, git_remote)
return removed

commit_ref_set: Set[ExpRefInfo] = set()
queued_ref_set: Set[int] = set()
if exp_names:
removed += _remove_exp_by_names(repo, remote, exp_names)
_resolve_exp_by_name(
repo, exp_names, commit_ref_set, queued_ref_set, git_remote
)

if rev:
_resolve_exp_by_baseline(repo, rev, num, commit_ref_set, git_remote)

if commit_ref_set:
removed += _remove_commited_exps(repo.scm, commit_ref_set, git_remote)

if queued_ref_set:
removed += _remove_queued_exps(repo, queued_ref_set)

return removed


def _resolve_exp_by_name(
    repo,
    exp_names: Union[str, List[str]],
    commit_ref_set: Set["ExpRefInfo"],
    queued_ref_set: Set[int],
    git_remote: Optional[str],
):
    """Sort experiment names into committed refs and queued stash indexes.

    Results are accumulated into the two sets passed in by the caller.

    Args:
        repo: repository object; ``repo.scm`` is used for name resolution.
        exp_names: a single name or a list of names to resolve.
        commit_ref_set: receives the ref of every name that
            ``resolve_name`` maps to a non-``None`` value.
        queued_ref_set: receives the stash index of every remaining name
            found by ``_get_queued_index_by_names``.
        git_remote: remote to resolve against; when set, the queue lookup
            is skipped.

    Raises:
        UnresolvedExpNamesError: if any name matched neither a committed
            experiment nor (for local lookups) a queued one.
    """
    if isinstance(exp_names, str):
        exp_names = [exp_names]

    remained = set()
    exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)
    for exp_name, exp_ref in exp_ref_dict.items():
        if exp_ref is None:
            remained.add(exp_name)
        else:
            commit_ref_set.add(exp_ref)

    if not git_remote:
        # Only names that are not committed experiments are looked up in
        # the queue.
        stash_index_dict = _get_queued_index_by_names(repo, remained)
        for exp_name, stash_index in stash_index_dict.items():
            if stash_index is not None:
                queued_ref_set.add(stash_index)
                remained.remove(exp_name)

    if remained:
        # Forward the remote so the error can report where the lookup
        # was attempted instead of always describing a local failure.
        raise UnresolvedExpNamesError(remained, git_remote=git_remote)


def _resolve_exp_by_baseline(
    repo,
    rev: str,
    num: int,
    commit_ref_set: Set["ExpRefInfo"],
    git_remote: Optional[str] = None,
):
    """Collect the experiment refs that hang off the selected baselines.

    The baselines are the keys returned by ``iter_revs`` for ``rev`` and
    ``num``; every ref that ``exp_refs_by_baseline`` reports for them is
    added to ``commit_ref_set`` in place.
    """
    baselines = set(iter_revs(repo.scm, [rev], num).keys())
    refs_by_baseline = exp_refs_by_baseline(repo.scm, baselines, git_remote)

    # Only the grouped refs matter here, not which baseline they belong to.
    for baseline_refs in refs_by_baseline.values():
        commit_ref_set.update(baseline_refs)


def _clear_stash(repo):
removed = len(repo.experiments.stash)
repo.experiments.stash.clear()
return removed


def _clear_all(repo):
ref_infos = list(exp_refs(repo.scm))
remove_exp_refs(repo.scm, ref_infos)
def _clear_all_commits(scm, git_remote):
ref_infos = list(exp_refs(scm, git_remote))
_remove_commited_exps(scm, ref_infos, git_remote)
return len(ref_infos)


def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
def _get_queued_index_by_names(
repo,
exp_name_set: Set[str],
) -> Mapping[str, Optional[int]]:
from scmrepo.exceptions import RevError as InternalRevError

result = {}
stash_revs = repo.experiments.stash_revs
for _, entry in stash_revs.items():
if entry.name == ref_or_rev:
return entry.stash_index

from dvc.scm import resolve_rev
if entry.name in exp_name_set:
result[entry.name] = entry.stash_index

try:
rev = resolve_rev(repo.scm, ref_or_rev)
if rev in stash_revs:
return stash_revs.get(rev).stash_index
except RevError:
pass
return None
for exp_name in exp_name_set:
if exp_name in result:
continue
try:
rev = repo.scm.resolve_rev(exp_name)
if rev in stash_revs:
result[exp_name] = stash_revs.get(rev).stash_index
except InternalRevError:
result[exp_name] = None
return result


def _remove_commited_exps(
repo, remote: Optional[str], exp_names: List[str]
) -> List[str]:
remain_list = []
remove_list = []
ref_info_dict = resolve_name(repo.scm, exp_names, remote)
for exp_name, ref_info in ref_info_dict.items():
if ref_info:
remove_list.append(ref_info)
else:
remain_list.append(exp_name)
if remove_list:
if not remote:
remove_exp_refs(repo.scm, remove_list)
else:
from dvc.scm import TqdmGit

for ref_info in remove_list:
with TqdmGit(desc="Pushing git refs") as pbar:
push_refspec(
repo.scm,
remote,
None,
str(ref_info),
progress=pbar.update_git,
)
return remain_list


def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
remain_list = []
for ref_or_rev in refs_or_revs:
stash_index = _get_exp_stash_index(repo, ref_or_rev)
if stash_index is None:
remain_list.append(ref_or_rev)
else:
repo.experiments.stash.drop(stash_index)
return remain_list


def _remove_exp_by_names(repo, remote, exp_names: List[str]) -> int:
remained = _remove_commited_exps(repo, remote, exp_names)
if not remote:
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
return len(exp_names) - len(remained)
scm: "Git", exp_ref_list: Collection["ExpRefInfo"], remote: Optional[str]
) -> int:
if remote:
from dvc.scm import TqdmGit

for ref_info in exp_ref_list:
with TqdmGit(desc="Pushing git refs") as pbar:
push_refspec(
scm,
remote,
None,
str(ref_info),
progress=pbar.update_git,
)
else:
remove_exp_refs(scm, exp_ref_list)
return len(exp_ref_list)


def _remove_queued_exps(repo, indexes: Collection[int]) -> int:
index_list = list(indexes)
index_list.sort(reverse=True)
for index in index_list:
repo.experiments.stash.drop(index)
return len(index_list)
Loading

0 comments on commit efc021d

Please sign in to comment.