Skip to content

Commit

Permalink
exp save: initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrifiro authored and daavoo committed Nov 28, 2022
1 parent 8124396 commit 73b46a5
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 5 deletions.
2 changes: 2 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
queue_worker,
remove,
run,
save,
show,
)

Expand All @@ -34,6 +35,7 @@
queue_worker,
remove,
run,
save,
show,
]

Expand Down
83 changes: 83 additions & 0 deletions dvc/commands/experiments/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import argparse
import logging

from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.exceptions import DvcException
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdExperimentsSave(CmdBase):
def run(self):

try:
ref = self.repo.experiments.save(
name=self.args.name, force=self.args.force
)
except DvcException:
logger.exception("failed to save experiment")
return 1

if self.args.json:
ui.write_json({"ref": ref})
# fixme: add metrics
else:
name = self.repo.experiments.get_exact_name(ref)
ui.write(f"Experiment has been saved as: {name}")
ui.write(
"\nTo promote an experiment to a Git branch run:\n\n"
"\tdvc exp branch <exp> <branch>\n"
)
if self.args.metrics:
from dvc.compare import show_metrics

metrics = self.repo.metrics.show(revs=(ref,))
metrics.pop("workspace", None)
show_metrics(metrics)

return 0


def add_parser(experiments_subparsers, parent_parser):
EXPERIMENTS_SAVE_HELP = "Save current workspace as a dvc experiment."
save_parser = experiments_subparsers.add_parser(
"save",
parents=[parent_parser],
description=append_doc_link(EXPERIMENTS_SAVE_HELP, "exp/save"),
help=EXPERIMENTS_SAVE_HELP,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
save_parser.add_argument(
"-f",
"--force",
action="store_true",
default=False,
help="Save even if hash value for dependencies/outputs changed.",
)
save_parser.add_argument(
"--json",
"--show-json",
action="store_true",
default=False,
help="Show output in JSON format.",
)
save_parser.add_argument(
"-m",
"--metrics",
action="store_true",
default=False,
help="Show metrics for the saved experiment.",
)
save_parser.add_argument(
"-n",
"--name",
default=None,
help=(
"Human-readable experiment name. If not specified, a name will "
"be auto-generated."
),
metavar="<name>",
)
save_parser.set_defaults(func=CmdExperimentsSave)
5 changes: 5 additions & 0 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,11 @@ def run(self, *args, **kwargs):

return run(self.repo, *args, **kwargs)

def save(self, *args, **kwargs):
from dvc.repo.experiments.save import save

return save(self.repo, *args, **kwargs)

def gc(self, *args, **kwargs):
from dvc.repo.experiments.gc import gc

Expand Down
61 changes: 59 additions & 2 deletions dvc/repo/experiments/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from shortuuid import uuid

from dvc.lock import LockError
from dvc.exceptions import DvcException
from dvc.scm import SCM, GitMergeError
from dvc.utils.fs import makedirs, remove

Expand All @@ -21,17 +22,18 @@
EXEC_MERGE,
EXEC_NAMESPACE,
EXPS_TEMP,
ExpRefInfo,
)
from ..utils import EXEC_TMP_DIR, get_exp_rwlock
from .base import BaseExecutor, TaskStatus
from .base import BaseExecutor, TaskStatus, ExecutorResult

if TYPE_CHECKING:
from scmrepo.git import Git

from dvc.repo import Repo

from ..refs import ExpRefInfo
from ..stash import ExpStashEntry
from .base import ExecutorInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,3 +249,58 @@ def cleanup(self, infofile: str):
checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)
if checkpoint and checkpoint != self._orig_checkpoint:
self.scm.set_ref(EXEC_APPLY, checkpoint)

@classmethod
def save(
cls,
info: "ExecutorInfo",
is_checkpoint: bool = False,
force: bool = False,
) -> ExecutorResult:
from dvc.repo import Repo

exp_hash: Optional[str] = None
exp_ref: Optional[ExpRefInfo] = None

dvc = Repo(os.path.join(info.root_dir, info.dvc_dir))
old_cwd = os.getcwd()
if info.wdir:
os.chdir(os.path.join(dvc.scm.root_dir, info.wdir))
else:
os.chdir(dvc.root_dir)

try:
stages = dvc.commit([], force=force)
exp_hash = cls.hash_exp(stages)
cls.commit(
dvc.scm,
exp_hash,
exp_name=info.name,
checkpoint=is_checkpoint,
force=force,
)
ref: Optional[str] = dvc.scm.get_ref(EXEC_BRANCH, follow=False)
exp_ref = ExpRefInfo.from_ref(ref) if ref else None
# TODO: research into how untracked files should be handled
if cls.WARN_UNTRACKED:
untracked = dvc.scm.untracked_files()
if untracked:
logger.warning(
"The following untracked files were present in "
"the experiment directory after reproduction but "
"will not be included in experiment commits:\n"
"\t%s",
", ".join(untracked),
)
info.result_hash = exp_hash
info.result_ref = ref
info.result_force = False
info.status = TaskStatus.SUCCESS
except DvcException:
info.status = TaskStatus.FAILED
raise
finally:
dvc.close()
os.chdir(old_cwd)

return ExecutorResult(ref, exp_ref, info.result_force)
32 changes: 32 additions & 0 deletions dvc/repo/experiments/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import os
from typing import TYPE_CHECKING, Optional

from funcy import first

if TYPE_CHECKING:
from dvc.repo import Repo


logger = logging.getLogger(__name__)


def save(
repo: "Repo",
name: Optional[str] = None,
force: bool = False,
) -> Optional[str]:
"""Save the current workspace status as an experiment.
Returns the saved experiment's SHAs.
"""
queue = repo.experiments.workspace_queue
logger.debug("Saving workspace in %s", os.getcwd())

entry = repo.experiments.new(queue=queue, name=name, force=force)
executor = queue.init_executor(repo.experiments, entry)
save_result = executor.save(executor.info, force=force)
result = queue.collect_executor(repo.experiments, executor, save_result)

exp_rev = first(result)
return exp_rev
5 changes: 2 additions & 3 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

from dvc.dvcfile import PIPELINE_FILE
from dvc.exceptions import ReproductionError
from dvc.repo.experiments.exceptions import ExperimentExistsError
from dvc.repo.experiments.queue.base import BaseStashQueue
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.scm import resolve_rev
from dvc.stage.exceptions import StageFileDoesNotExistError
from dvc.stage.exceptions import StageCommitError, StageFileDoesNotExistError
from dvc.utils.serialize import PythonFileCorruptedError
from tests.scripts import COPY_SCRIPT

Expand Down Expand Up @@ -44,8 +45,6 @@ def test_new_simple(tmp_dir, scm, dvc, exp_stage, mocker, name, workspace):


def test_experiment_exists(tmp_dir, scm, dvc, exp_stage, mocker, workspace):
from dvc.repo.experiments.exceptions import ExperimentExistsError

dvc.experiments.run(
exp_stage.addressing,
name="foo",
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dvc.commands.experiments.push import CmdExperimentsPush
from dvc.commands.experiments.remove import CmdExperimentsRemove
from dvc.commands.experiments.run import CmdExperimentsRun
from dvc.commands.experiments.save import CmdExperimentsSave
from dvc.commands.experiments.show import CmdExperimentsShow, show_experiments
from dvc.exceptions import InvalidArgumentError
from dvc.repo import Repo
Expand Down Expand Up @@ -934,3 +935,15 @@ def test_show_experiments_pcp(tmp_dir, mocker):

assert kwargs["output_path"] == str(tmp_dir / "dvc_plots" / "index.html")
assert kwargs["color_by"] == "Experiment"


def test_experiments_save(dvc, scm, mocker):
cli_args = parse_args(["exp", "save", "--name", "exp-name", "--force"])
assert cli_args.func == CmdExperimentsSave

cmd = cli_args.func(cli_args)
m = mocker.patch("dvc.repo.experiments.save.save", return_value="acabb")

assert cmd.run() == 0

m.assert_called_once_with(cmd.repo, name="exp-name", force=True)

0 comments on commit 73b46a5

Please sign in to comment.