Skip to content

Commit

Permalink
exp save: cleanup.
Browse files Browse the repository at this point in the history
Remove metrics.

Move tests to separate file.

Fix `experiments.get_exact_name` usage.

Remove unused checkpoint logic.
  • Loading branch information
daavoo committed Nov 28, 2022
1 parent 73b46a5 commit 43e2815
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 29 deletions.
16 changes: 1 addition & 15 deletions dvc/commands/experiments/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,13 @@ def run(self):

if self.args.json:
ui.write_json({"ref": ref})
# fixme: add metrics
else:
name = self.repo.experiments.get_exact_name(ref)
name = self.repo.experiments.get_exact_name([ref])[ref]
ui.write(f"Experiment has been saved as: {name}")
ui.write(
"\nTo promote an experiment to a Git branch run:\n\n"
"\tdvc exp branch <exp> <branch>\n"
)
if self.args.metrics:
from dvc.compare import show_metrics

metrics = self.repo.metrics.show(revs=(ref,))
metrics.pop("workspace", None)
show_metrics(metrics)

return 0

Expand Down Expand Up @@ -63,13 +56,6 @@ def add_parser(experiments_subparsers, parent_parser):
default=False,
help="Show output in JSON format.",
)
save_parser.add_argument(
"-m",
"--metrics",
action="store_true",
default=False,
help="Show metrics for the saved experiment.",
)
save_parser.add_argument(
"-n",
"--name",
Expand Down
22 changes: 9 additions & 13 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def cleanup(self, infofile: str):
def save(
cls,
info: "ExecutorInfo",
is_checkpoint: bool = False,
force: bool = False,
) -> ExecutorResult:
from dvc.repo import Repo
Expand All @@ -276,22 +275,19 @@ def save(
dvc.scm,
exp_hash,
exp_name=info.name,
checkpoint=is_checkpoint,
force=force,
)
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
# TODO: research into how untracked files should be handled
if cls.WARN_UNTRACKED:
untracked = dvc.scm.untracked_files()
if untracked:
logger.warning(
"The following untracked files were present in "
"the experiment directory after reproduction but "
"will not be included in experiment commits:\n"
"\t%s",
", ".join(untracked),
)
untracked = dvc.scm.untracked_files()
if untracked:
logger.warning(
"The following untracked files were present in "
"the workspace before saving but "
"will not be included in the experiment commit:\n"
"\t%s",
", ".join(untracked),
)
info.result_hash = exp_hash
info.result_ref = ref
info.result_force = False
Expand Down
2 changes: 1 addition & 1 deletion tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from dvc.repo.experiments.queue.base import BaseStashQueue
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
from dvc.stage.exceptions import StageCommitError, StageFileDoesNotExistError
from dvc.stage.exceptions import StageFileDoesNotExistError
from dvc.utils.serialize import PythonFileCorruptedError
from tests.scripts import COPY_SCRIPT

Expand Down
78 changes: 78 additions & 0 deletions tests/func/experiments/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from contextlib import nullcontext

import pytest
from funcy import first

from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
from dvc.stage.exceptions import StageCommitError


@pytest.mark.parametrize("name", (None, "test"))
def test_exp_save(tmp_dir, dvc, scm, exp_stage, name):
baseline = scm.get_rev()

exp = dvc.experiments.save(name=name)
ref_info = first(exp_refs_by_rev(scm, exp))
assert ref_info and ref_info.baseline_sha == baseline

exp_name = name if name else ref_info.name
assert dvc.experiments.get_exact_name([exp])[exp] == exp_name
assert resolve_rev(scm, exp_name) == exp


@pytest.mark.parametrize(
("force", "expected_raises"),
(
(False, pytest.raises(StageCommitError)),
(True, nullcontext()),
),
)
def test_exp_save_force(tmp_dir, dvc, scm, exp_stage, force, expected_raises):
with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh:
fh.write("\n# dummy change")

with expected_raises:
dvc.experiments.save(force=force)


def test_exp_save_overwrite_experiment(tmp_dir, dvc, scm, exp_stage):
dvc.experiments.save(name="dummy")

with open(tmp_dir / "copy.py", "a", encoding="utf-8") as fh:
fh.write("\n# dummy change")

with pytest.raises(ExperimentExistsError):
dvc.experiments.save(name="dummy")

dvc.experiments.save(name="dummy", force=True)


def test_exp_save_multiple(tmp_dir, dvc, scm, exp_stage):
baseline = scm.get_rev()
for i in range(2):
name = f"exp-{i}"
tmp_dir.gen({name: f"{name} content"})
dvc.experiments.save(name=name)

assert dvc.experiments.ls()[baseline] == ["exp-0", "exp-1"]

for i in range(2):
scm.reset(hard=True)
name = f"exp-{i}"
dvc.experiments.apply(name)
assert (tmp_dir / name).read_text() == f"{name} content"


def test_exp_save_after_commit(tmp_dir, dvc, scm, exp_stage):
baseline = scm.get_rev()
dvc.experiments.save(name="exp-1")

tmp_dir.scm_gen({"new_file": "new_file"}, commit="new baseline")
new_baseline = scm.get_rev()
dvc.experiments.save(name="exp-2")

all_exps = dvc.experiments.ls(all_commits=True)
assert all_exps[baseline] == ["exp-1"]
assert all_exps[new_baseline] == ["exp-2"]

0 comments on commit 43e2815

Please sign in to comment.