Skip to content

Commit

Permalink
checkpoints: completely remove checkpoint outs on exp run --reset (iterative#5586)
Browse files Browse the repository at this point in the history

* checkpoints: remove checkpoint outs on --reset

* exp run: no longer prune/reset lockfiles

* update tests

* make --queue imply --reset unless --rev is provided
  • Loading branch information
pmrowla authored Mar 12, 2021
1 parent 11b0581 commit 8c1d46e
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 70 deletions.
5 changes: 5 additions & 0 deletions dvc/command/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,11 @@ class CmdExperimentsRun(CmdRepro):
def run(self):
from dvc.command.metrics import _show_metrics

if self.args.reset and self.args.checkpoint_resume:
raise InvalidArgumentError(
"--reset and --rev are mutually exclusive."
)

if self.args.reset:
logger.info("Any existing checkpoints will be reset and re-run.")

Expand Down
6 changes: 6 additions & 0 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def checkout(
relink=False,
filter_info=None,
allow_missing=False,
checkpoint_reset=False,
**kwargs,
):
if not self.use_cache:
Expand All @@ -422,6 +423,11 @@ def checkout(
# backward compatibility
return None

if self.checkpoint and checkpoint_reset:
if self.exists:
self.remove()
return None

added = not self.exists

try:
Expand Down
43 changes: 11 additions & 32 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ def _stash_exp(
)
self.scm.reset()

self._prune_lockfiles()

# update experiment params from command line
if params:
self._update_params(params)
Expand Down Expand Up @@ -244,34 +242,6 @@ def _stash_exp(

return stash_rev

def _prune_lockfiles(self):
    """Restore DVC lock files found in the HEAD tree to their index state.

    Dirty lock files must be reset before an experiment is stashed so
    that incorrect persist or checkpoint outs are not checked out.
    Does nothing when no lock file is found.
    """
    from dvc.dvcfile import is_lock_file

    head_fs = self.scm.get_fs("HEAD")

    # Gather the lock files present under the repository root in HEAD.
    tracked_locks = []
    for path in head_fs.walk_files(self.scm.root_dir):
        if is_lock_file(path):
            tracked_locks.append(str(path))

    if not tracked_locks:
        return

    # Unstage any changes first, then force the working copies back to
    # what the index holds.
    self.scm.reset(paths=tracked_locks)
    self.scm.checkout_index(paths=tracked_locks, force=True)

def _prune_untracked_lockfiles(self):
    """Remove every untracked file that is recognised as a DVC lock file."""
    from dvc.dvcfile import is_lock_file
    from dvc.utils.fs import remove

    # Materialise the full list before deleting anything, so removal
    # never happens while the SCM listing is still being consumed.
    stale_locks = list(filter(is_lock_file, self.scm.untracked_files()))
    for path in stale_locks:
        remove(path)

def _stash_msg(
self,
rev: str,
Expand Down Expand Up @@ -345,9 +315,17 @@ def reproduce_one(
queue: bool = False,
tmp_dir: bool = False,
checkpoint_resume: Optional[str] = None,
reset: bool = False,
**kwargs,
):
"""Reproduce and checkout a single experiment."""
if queue and not checkpoint_resume:
reset = True

if reset:
self.reset_checkpoints()
kwargs["force"] = True

if not (queue or tmp_dir):
staged, _, _ = self.scm.status()
if staged:
Expand All @@ -370,7 +348,9 @@ def reproduce_one(
else:
checkpoint_resume = self._workspace_resume_rev()

stash_rev = self.new(checkpoint_resume=checkpoint_resume, **kwargs)
stash_rev = self.new(
checkpoint_resume=checkpoint_resume, reset=reset, **kwargs
)
if queue:
logger.info(
"Queued experiment '%s' for future execution.", stash_rev[:7],
Expand Down Expand Up @@ -709,7 +689,6 @@ def _workspace_repro(self) -> Mapping[str, str]:
# result in conflict between workspace params and stashed CLI params).
self.scm.reset(hard=True)
with self.scm.detach_head(entry.rev):
self._prune_untracked_lockfiles()
rev = self.stash.pop()
self.scm.set_ref(EXEC_BASELINE, entry.baseline_rev)
if entry.branch:
Expand Down
17 changes: 10 additions & 7 deletions dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,24 +259,27 @@ def filter_pipeline(stages):
"Executor repro with force = '%s'", str(repro_force)
)

# NOTE: for checkpoint experiments we handle persist outs slightly
# differently than normal:
# NOTE: checkpoint outs are handled as a special type of persist
# out:
#
# - checkpoint out may not yet exist if this is the first time this
# experiment has been run, this is not an error condition for
# experiments
# - at the start of a repro run, we need to remove the persist out
# and restore it to its last known (committed) state (which may
# be removed/does not yet exist) so that our executor workspace
# is not polluted with the (persistent) out from an unrelated
# experiment run
# - if experiment was run with --reset, the checkpoint out will be
# removed at the start of the experiment (regardless of any
# dvc.lock entry for the checkpoint out)
# - if run without --reset, the checkpoint out will be checked out
# using any hash present in dvc.lock (or removed if no entry
# exists in dvc.lock)
checkpoint_reset = kwargs.pop("reset", False)
dvc_checkout(
dvc,
targets=targets,
with_deps=targets is not None,
force=True,
quiet=True,
allow_missing=True,
checkpoint_reset=checkpoint_reset,
)

checkpoint_func = partial(
Expand Down
5 changes: 0 additions & 5 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def run(
run_all: bool = False,
jobs: int = 1,
tmp_dir: bool = False,
reset: bool = False,
**kwargs,
) -> dict:
"""Reproduce the specified targets as an experiment.
Expand All @@ -25,10 +24,6 @@ def run(
Returns a dict mapping new experiment SHAs to the results
of `repro` for that experiment.
"""
if reset:
repo.experiments.reset_checkpoints()
kwargs["force"] = True

if run_all:
return repo.experiments.reproduce_queued(jobs=jobs)

Expand Down
4 changes: 0 additions & 4 deletions tests/func/experiments/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ def test_reset_checkpoint(
checkpoint_stage.addressing, name="foo", tmp_dir=not workspace,
)

if workspace:
scm.reset(hard=True)
scm.gitpython.repo.git.clean(force=True)

results = dvc.experiments.run(
checkpoint_stage.addressing,
params=["foo=2"],
Expand Down
22 changes: 0 additions & 22 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,28 +418,6 @@ def test_untracked(tmp_dir, scm, dvc, caplog, workspace):
assert fobj.read().strip() == "foo: 2"


@pytest.mark.parametrize("workspace", [True, False])
def test_dirty_lockfile(tmp_dir, scm, dvc, exp_stage, workspace):
    """An experiment run must succeed even when dvc.lock is corrupted.

    A plain ``dvc.reproduce`` is expected to fail on the bogus lock file,
    while ``experiments.run`` still produces the expected metrics. When
    run outside the workspace, the dirty lock file must be left as-is.
    """
    from dvc.dvcfile import LockfileCorruptedError

    tmp_dir.gen("dvc.lock", "foo")

    # Sanity check: the lock file really is unusable for a normal repro.
    with pytest.raises(LockfileCorruptedError):
        dvc.reproduce(exp_stage.addressing)

    use_tmp_dir = not workspace
    exp_results = dvc.experiments.run(
        exp_stage.addressing, params=["foo=2"], tmp_dir=use_tmp_dir
    )
    exp_rev = first(exp_results)

    exp_fs = scm.get_fs(exp_rev)
    with exp_fs.open(tmp_dir / "metrics.yaml") as fobj:
        metrics_text = fobj.read().strip()
        assert metrics_text == "foo: 2"

    if workspace:
        return

    # A temp-dir run must not modify the user's dirty lock file.
    assert (tmp_dir / "dvc.lock").read_text() == "foo"


def test_packed_args_exists(tmp_dir, scm, dvc, exp_stage, caplog):
from dvc.repo.experiments.executor.base import BaseExecutor

Expand Down

0 comments on commit 8c1d46e

Please sign in to comment.