Skip to content

Commit

Permalink
experiments: fix checkpoint/persist related checkout issues (iterative#4885)
Browse files Browse the repository at this point in the history

* dvcfile: add is_lock_file to identify pipeline lock files

* experiments: clean dirty lock files before stashing experiments
  • Loading branch information
pmrowla authored Nov 12, 2020
1 parent eeba567 commit 3daeeaf
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
6 changes: 5 additions & 1 deletion dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,14 @@ def is_valid_filename(path):

def is_dvc_file(path):
    """Return True if ``path`` is an existing file that counts as a DVC file.

    The path must be a regular file on disk, and additionally either be
    accepted by ``is_valid_filename`` or be identified as a pipeline lock
    file by ``is_lock_file``.
    """
    return os.path.isfile(path) and (
        is_valid_filename(path) or is_lock_file(path)
    )


def is_lock_file(path):
    """Tell whether ``path`` names a pipeline lock file.

    Only the final path component is compared against ``PIPELINE_LOCK``;
    the directory part is ignored and the file does not have to exist.
    """
    _, filename = os.path.split(path)
    return filename == PIPELINE_LOCK


def check_dvc_filename(path):
if not is_valid_filename(path):
raise StageFileBadNameError(
Expand Down
18 changes: 17 additions & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from funcy import cached_property, first

from dvc.dvcfile import is_lock_file
from dvc.exceptions import DownloadError, DvcException, UploadError
from dvc.path_info import PathInfo
from dvc.progress import Tqdm
Expand Down Expand Up @@ -284,7 +285,7 @@ def _stash_exp(
if params:
self._update_params(params)

if not self.scm.is_dirty(untracked_files=True) and not allow_unchanged:
if not self._check_dirty() and not allow_unchanged:
# experiment matches original baseline
raise UnchangedExperimentError(rev)

Expand All @@ -298,6 +299,21 @@ def _stash_exp(
self.scm.repo.git.stash("push", "-m", msg)
return self.scm.resolve_rev("stash@{0}")

def _check_dirty(self) -> bool:
    """Reset DVC lock files, then report whether other changes remain.

    Modified lock files are force-restored from the index and untracked
    lock files are deleted from the working tree. Everything else is left
    untouched and only counted.

    Returns:
        True if at least one modified or untracked path that is not a
        lock file was found, False otherwise.
    """
    import os

    repo = self.scm.repo

    # NOTE: dirty DVC lock files must be restored to index state to
    # avoid checking out incorrect persist or checkpoint outs
    # NOTE(review): index.diff(None) presumably compares the index with the
    # working tree only, so staged-only changes would not be counted here
    # (the is_dirty() call this replaced appears to include them) - confirm
    # that this is intended.
    dirty = [diff.a_path for diff in repo.index.diff(None)]
    to_checkout = [fname for fname in dirty if is_lock_file(fname)]
    if to_checkout:
        # Nothing to restore otherwise, so skip the checkout call entirely.
        repo.index.checkout(paths=to_checkout, force=True)

    untracked = repo.untracked_files
    to_remove = [fname for fname in untracked if is_lock_file(fname)]
    # Untracked paths are reported relative to the working tree root, so
    # anchor them there instead of relying on the process working directory.
    root = repo.working_tree_dir or ""
    for fname in to_remove:
        remove(os.path.join(root, fname))

    remaining_dirty = len(dirty) - len(to_checkout)
    remaining_untracked = len(untracked) - len(to_remove)
    return (remaining_dirty + remaining_untracked) != 0

def _stash_msg(self, rev, branch=None):
if branch:
return f"{self.STASH_MSG_PREFIX}{rev}:{branch}"
Expand Down

0 comments on commit 3daeeaf

Please sign in to comment.