Skip to content

Commit

Permalink
exp: save: Handle staged changes.
Browse files Browse the repository at this point in the history
Staged changes were causing a merge error.
Warn and unstage instead of erroring, to match exp run behavior.
  • Loading branch information
daavoo committed Nov 28, 2022
1 parent 43e2815 commit 665185b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
10 changes: 9 additions & 1 deletion dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from funcy import first

Expand All @@ -23,6 +23,14 @@ def save(
queue = repo.experiments.workspace_queue
logger.debug("Saving workspace in %s", os.getcwd())

staged, _, _ = repo.scm.status(untracked_files="no")
if staged:
logger.warning(
"Your workspace contains staged Git changes which will be "
"unstaged before saving this experiment."
)
repo.scm.reset()

entry = repo.experiments.new(queue=queue, name=name, force=force)
executor = queue.init_executor(repo.experiments, entry)
save_result = executor.save(executor.info, force=force)
Expand Down
12 changes: 11 additions & 1 deletion tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage):
dvc.experiments.save(name="dummy", force=True)


def test_exp_save_multiple(tmp_dir, dvc, scm, exp_stage):
def test_exp_save_multiple(tmp_dir, dvc, scm):
baseline = scm.get_rev()
for i in range(2):
name = f"exp-{i}"
Expand All @@ -76,3 +76,13 @@ def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):
all_exps = dvc.experiments.ls(all_commits=True)
assert all_exps[baseline] == ["exp-1"]
assert all_exps[new_baseline] == ["exp-2"]


def test_exp_save_with_staged_changes(tmp_dir, dvc, scm):
tmp_dir.gen({"new_file": "new_file"})
scm.add("new_file")

dvc.experiments.save(name="exp")

_, _, unstaged = scm.status()
assert "new_file" in unstaged

0 comments on commit 665185b

Please sign in to comment.