Skip to content

Commit

Permalink
Remove specific queued experiments (iterative#6393)
Browse files Browse the repository at this point in the history
* Remove specific queued experiments

fix iterative#6157
1. `dvc exp remove` did not accept the names of queued experiments
2. Add some tests for this feature

* Extract tests and add revision support

1. Extract the experiment-removal tests into a new file.
2. A revision can be used to remove a specific queued experiment.

* Accept shortened revisions

* Split removing committed and queued exp functions

* Api name change

* Return to the old API name

* shorten some functions

* Error message change

* Better test cases, more corner cases

* Still raise exception in a mixed case
  • Loading branch information
karajan1001 authored Aug 13, 2021
1 parent 4c957bf commit de97a49
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 35 deletions.
78 changes: 58 additions & 20 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
from typing import List, Optional

from dvc.exceptions import InvalidArgumentError
from dvc.repo import locked
from dvc.repo.scm_context import scm_context
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs_by_name, remove_exp_refs
Expand All @@ -21,32 +23,43 @@ def remove(repo, exp_names=None, queue=False, **kwargs):
removed += len(repo.experiments.stash)
repo.experiments.stash.clear()
if exp_names:
ref_infos = list(_get_exp_refs(repo, exp_names))
remove_exp_refs(repo.scm, ref_infos)
removed += len(ref_infos)
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
if remained:
raise InvalidArgumentError(
"'{}' is not a valid experiment".format(";".join(remained))
)
removed += len(exp_names) - len(remained)
return removed


def _get_exp_refs(repo, exp_names):
cur_rev = repo.scm.get_rev()
for name in exp_names:
if name.startswith(EXPS_NAMESPACE):
if not repo.scm.get_ref(name):
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield ExpRefInfo.from_ref(name)
else:
def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
stash_revs = repo.experiments.stash_revs
for _, ref_info in stash_revs.items():
if ref_info.name == ref_or_rev:
return ref_info.index
try:
rev = repo.scm.resolve_rev(ref_or_rev)
if rev in stash_revs:
return stash_revs.get(rev).index
except RevError:
pass
return None

exp_refs = list(exp_refs_by_name(repo.scm, name))
if not exp_refs:
raise InvalidArgumentError(
f"'{name}' is not a valid experiment name"
)
yield _get_ref(exp_refs, name, cur_rev)

def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
    """Resolve ``exp_name`` to the ref info of a committed experiment.

    A name starting with ``EXPS_NAMESPACE`` is treated as a full ref name
    and checked with ``repo.scm.get_ref``. Any other name is looked up with
    ``exp_refs_by_name`` and the candidates are handed to ``_get_ref``
    together with the current revision.

    Args:
        repo: Repo object exposing ``scm``.
        exp_name: Full experiment ref name or plain experiment name.

    Returns:
        The matching ``ExpRefInfo``, or ``None`` when no such ref exists.
        Whatever ``_get_ref`` raises for the candidate list propagates
        unchanged.
    """
    if exp_name.startswith(EXPS_NAMESPACE):
        if repo.scm.get_ref(exp_name):
            return ExpRefInfo.from_ref(exp_name)
        return None

    exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
    if not exp_refs:
        return None
    # The current revision is only needed to pick among the candidates, so
    # it is queried here instead of on every call.
    return _get_ref(exp_refs, exp_name, repo.scm.get_rev())


def _get_ref(ref_infos, name, cur_rev):
def _get_ref(ref_infos, name, cur_rev) -> Optional[ExpRefInfo]:
if len(ref_infos) > 1:
for info in ref_infos:
if info.baseline_sha == cur_rev:
Expand All @@ -61,3 +74,28 @@ def _get_ref(ref_infos, name, cur_rev):
msg.extend([f"\t{info}" for info in ref_infos])
raise InvalidArgumentError("\n".join(msg))
return ref_infos[0]


def _remove_commited_exps(repo, refs: List[str]) -> List[str]:
    """Remove every committed experiment named in ``refs``.

    Each entry is resolved with ``_get_exp_ref``. All entries that resolve
    are deleted in a single ``remove_exp_refs`` call.

    Args:
        repo: Repo object exposing ``scm``.
        refs: Experiment names or full ref names to remove.

    Returns:
        The entries of ``refs`` that did not resolve to a committed
        experiment, in their original order.
    """
    resolved = []
    unresolved = []
    for candidate in refs:
        info = _get_exp_ref(repo, candidate)
        if not info:
            unresolved.append(candidate)
            continue
        resolved.append(info)

    # Skip the SCM call entirely when nothing resolved.
    if resolved:
        remove_exp_refs(repo.scm, resolved)
    return unresolved


def _remove_queued_exps(repo, refs_or_revs: List[str]) -> List[str]:
    """Drop every queued experiment named in ``refs_or_revs`` from the stash.

    Args:
        repo: Repo object exposing ``experiments.stash``.
        refs_or_revs: Queued experiment names or revisions to remove.

    Returns:
        The entries for which ``_get_exp_stash_index`` found no stash index,
        in their original order.
    """
    unmatched = []
    for candidate in refs_or_revs:
        # The index is looked up right before each drop.
        # NOTE(review): presumably stash indices shift after a drop, so this
        # ordering matters - confirm against the stash implementation.
        index = _get_exp_stash_index(repo, candidate)
        if index is not None:
            repo.experiments.stash.drop(index)
            continue
        unmatched.append(candidate)
    return unmatched
15 changes: 0 additions & 15 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,21 +576,6 @@ def test_run_metrics(tmp_dir, scm, dvc, exp_stage, mocker):
assert show_mock.called_once()


def test_remove(tmp_dir, scm, dvc, exp_stage):
    """Remove one committed experiment by ref, then clear the queue."""
    committed_rev = first(
        dvc.experiments.run(exp_stage.addressing, params=["foo=2"])
    )
    ref_info = first(exp_refs_by_rev(scm, committed_rev))
    dvc.experiments.run(exp_stage.addressing, params=["foo=3"], queue=True)

    # Removing by full ref name deletes exactly that ref.
    assert dvc.experiments.remove([str(ref_info)]) == 1
    assert scm.get_ref(str(ref_info)) is None

    # The queued experiment is still there until the queue is cleared.
    assert dvc.experiments.remove(queue=True) == 1
    assert len(dvc.experiments.stash) == 0


def test_checkout_targets_deps(tmp_dir, scm, dvc, exp_stage):
from dvc.utils.fs import remove

Expand Down
74 changes: 74 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest
from funcy import first

from dvc.exceptions import InvalidArgumentError
from dvc.repo.experiments.utils import exp_refs_by_rev


def test_remove_experiments_by_ref(tmp_dir, scm, dvc, exp_stage, caplog):
    """Valid refs are removed even when the same call raises for a bad name.

    Three committed experiments are created. Removing the first two together
    with a non-existent name must raise ``InvalidArgumentError``, delete the
    two valid refs and leave the third one in place.
    """
    exp_count = 3
    ref_list = []

    for i in range(exp_count):
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={i}"]
        )
        ref_info = first(exp_refs_by_rev(scm, first(results)))
        ref_list.append(str(ref_info))

    # The call raises, so there is no return value to assert on.
    with pytest.raises(InvalidArgumentError):
        dvc.experiments.remove(ref_list[:2] + ["non-exist"])
    # ref_list already holds strings; no further conversion is needed.
    assert scm.get_ref(ref_list[0]) is None
    assert scm.get_ref(ref_list[1]) is None
    assert scm.get_ref(ref_list[2]) is not None


def test_remove_all_queued_experiments(tmp_dir, scm, dvc, exp_stage):
    """``remove(queue=True)`` empties the queue and keeps committed refs."""
    queued_total = 3

    for value in range(queued_total):
        dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={value}"], queue=True
        )

    # One committed (non-queued) experiment that must survive the removal.
    committed_rev = first(
        dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={queued_total}"]
        )
    )
    ref_info = first(exp_refs_by_rev(scm, committed_rev))

    assert len(dvc.experiments.stash) == queued_total
    removed = dvc.experiments.remove(queue=True)
    assert removed == queued_total
    assert len(dvc.experiments.stash) == 0
    assert scm.get_ref(str(ref_info)) is not None


def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage):
    """Remove selected queued and committed experiments in a single call.

    Three named experiments are queued and two are committed. One queued
    experiment is removed by name, one by a shortened revision, and one
    committed experiment by its ref; the rest must be left untouched.
    """
    queued_revs = []
    for number in (1, 2, 3):
        results = dvc.experiments.run(
            exp_stage.addressing,
            params=[f"foo={number}"],
            queue=True,
            name=f"queue{number}",
        )
        queued_revs.append(first(results))
    rev1, rev2, rev3 = queued_revs

    committed_refs = []
    for number in (4, 5):
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={number}"]
        )
        committed_refs.append(first(exp_refs_by_rev(scm, first(results))))
    ref_info1, ref_info2 = committed_refs

    # Sanity check: everything exists before removal.
    for rev in (rev1, rev2, rev3):
        assert rev in dvc.experiments.stash_revs
    for ref_info in (ref_info1, ref_info2):
        assert scm.get_ref(str(ref_info)) is not None

    targets = ["queue1", rev2[:5], str(ref_info1)]
    assert dvc.experiments.remove(targets) == 3

    assert rev1 not in dvc.experiments.stash_revs
    assert rev2 not in dvc.experiments.stash_revs
    assert rev3 in dvc.experiments.stash_revs
    assert scm.get_ref(str(ref_info1)) is None
    assert scm.get_ref(str(ref_info2)) is not None

0 comments on commit de97a49

Please sign in to comment.