Skip to content

Commit

Permalink
Remove all of the experiments in the workspace (iterative#6394)
Browse files Browse the repository at this point in the history
* Remove all of the experiments in the workspace

fix iterative#5676
1. remove all the experiments in workspace
2. add tests for it

* Only remove baseline ones

* Add a new arguments --all

* Update dvc/repo/experiments/remove.py

Co-authored-by: Peter Rowlands (변기호) <[email protected]>

* Rename and shortcut for these commands

* Solve the issue from removing a changing list.

* Update dvc/repo/experiments/remove.py

Co-authored-by: Peter Rowlands (변기호) <[email protected]>

* Update dvc/repo/experiments/remove.py

Co-authored-by: Peter Rowlands (변기호) <[email protected]>

* delete workspace option

Co-authored-by: Peter Rowlands (변기호) <[email protected]>
  • Loading branch information
karajan1001 and pmrowla authored Aug 19, 2021
1 parent 4b9a411 commit 57899ee
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 14 deletions.
13 changes: 11 additions & 2 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,9 @@ class CmdExperimentsRemove(CmdBase):
def run(self):

self.repo.experiments.remove(
exp_names=self.args.experiment, queue=self.args.queue
exp_names=self.args.experiment,
queue=self.args.queue,
clear_all=self.args.all,
)

return 0
Expand Down Expand Up @@ -1237,9 +1239,16 @@ def add_parser(subparsers, parent_parser):
help=EXPERIMENTS_REMOVE_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
experiments_remove_parser.add_argument(
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
remove_group.add_argument(
"-A",
"--all",
action="store_true",
help="Remove all committed experiments.",
)
experiments_remove_parser.add_argument(
"experiment",
nargs="*",
Expand Down
36 changes: 28 additions & 8 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,29 @@
from dvc.scm.base import RevError

from .base import EXPS_NAMESPACE, ExpRefInfo
from .utils import exp_refs_by_name, remove_exp_refs
from .utils import exp_refs, exp_refs_by_name, remove_exp_refs

logger = logging.getLogger(__name__)


@locked
@scm_context
def remove(repo, exp_names=None, queue=False, **kwargs):
if not exp_names and not queue:
def remove(
repo,
exp_names=None,
queue=False,
clear_all=False,
**kwargs,
):
if not any([exp_names, queue, clear_all]):
return 0

removed = 0
if queue:
removed += len(repo.experiments.stash)
repo.experiments.stash.clear()
removed += _clear_stash(repo)
if clear_all:
removed += _clear_all(repo)

if exp_names:
remained = _remove_commited_exps(repo, exp_names)
remained = _remove_queued_exps(repo, remained)
Expand All @@ -33,6 +41,18 @@ def remove(repo, exp_names=None, queue=False, **kwargs):
return removed


def _clear_stash(repo):
removed = len(repo.experiments.stash)
repo.experiments.stash.clear()
return removed


def _clear_all(repo):
ref_infos = list(exp_refs(repo.scm))
remove_exp_refs(repo.scm, ref_infos)
return len(ref_infos)


def _get_exp_stash_index(repo, ref_or_rev: str) -> Optional[int]:
stash_revs = repo.experiments.stash_revs
for _, ref_info in stash_revs.items():
Expand All @@ -53,9 +73,9 @@ def _get_exp_ref(repo, exp_name: str) -> Optional[ExpRefInfo]:
if repo.scm.get_ref(exp_name):
return ExpRefInfo.from_ref(exp_name)
else:
exp_refs = list(exp_refs_by_name(repo.scm, exp_name))
if exp_refs:
return _get_ref(exp_refs, exp_name, cur_rev)
exp_ref_list = list(exp_refs_by_name(repo.scm, exp_name))
if exp_ref_list:
return _get_ref(exp_ref_list, exp_name, cur_rev)
return None


Expand Down
18 changes: 18 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,21 @@ def test_remove_special_queued_experiments(tmp_dir, scm, dvc, exp_stage):
assert rev3 in dvc.experiments.stash_revs
assert scm.get_ref(str(ref_info1)) is None
assert scm.get_ref(str(ref_info2)) is not None


def test_remove_all(tmp_dir, scm, dvc, exp_stage):
results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"])
ref_info1 = first(exp_refs_by_rev(scm, first(results)))
dvc.experiments.run(exp_stage.addressing, params=["foo=2"], queue=True)
scm.add(["dvc.yaml", "dvc.lock", "copy.py", "params.yaml", "metrics.yaml"])
scm.commit("update baseline")

results = dvc.experiments.run(exp_stage.addressing, params=["foo=3"])
ref_info2 = first(exp_refs_by_rev(scm, first(results)))
dvc.experiments.run(exp_stage.addressing, params=["foo=4"], queue=True)

removed = dvc.experiments.remove(clear_all=True)
assert removed == 2
assert len(dvc.experiments.stash) == 2
assert scm.get_ref(str(ref_info2)) is None
assert scm.get_ref(str(ref_info1)) is None
20 changes: 16 additions & 4 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,25 @@ def test_experiments_pull(dvc, scm, mocker):
)


def test_experiments_remove(dvc, scm, mocker):
cli_args = parse_args(["experiments", "remove", "--queue"])
@pytest.mark.parametrize(
"queue,clear_all",
[(True, False), (False, True)],
)
def test_experiments_remove(dvc, scm, mocker, queue, clear_all):
if queue:
args = "--queue"
if clear_all:
args = "--all"
cli_args = parse_args(["experiments", "remove", args])
assert cli_args.func == CmdExperimentsRemove

cmd = cli_args.func(cli_args)
m = mocker.patch("dvc.repo.experiments.remove.remove", return_value={})

assert cmd.run() == 0

m.assert_called_once_with(cmd.repo, exp_names=[], queue=True)
m.assert_called_once_with(
cmd.repo,
exp_names=[],
queue=queue,
clear_all=clear_all,
)

0 comments on commit 57899ee

Please sign in to comment.