Skip to content

Commit

Permalink
exp run: Support --message.
Browse files Browse the repository at this point in the history
Custom commit message to use when committing the experiment.

Part of iterative#8870
  • Loading branch information
daavoo committed May 8, 2023
1 parent 1ba3944 commit c74dc20
Show file tree
Hide file tree
Showing 15 changed files with 97 additions and 12 deletions.
8 changes: 8 additions & 0 deletions dvc/commands/experiments/exec_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run(self):
log_level=logger.getEffectiveLevel(),
infofile=self.args.infofile,
copy_paths=self.args.copy_paths,
message=self.args.message,
)
return 0

Expand Down Expand Up @@ -47,4 +48,11 @@ def add_parser(experiments_subparsers, parent_parser):
" Only used if `--temp` or `--queue` is specified."
),
)
exec_run_parser.add_argument(
"-M",
"--message",
type=str,
default=None,
help="Custom commit message to use when committing the experiment.",
)
exec_run_parser.set_defaults(func=CmdExecutorRun)
8 changes: 8 additions & 0 deletions dvc/commands/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def run(self):
tmp_dir=self.args.tmp_dir,
machine=self.args.machine,
copy_paths=self.args.copy_paths,
message=self.args.message,
**self._common_kwargs,
)

Expand Down Expand Up @@ -147,3 +148,10 @@ def _add_run_common(parser):
" Only used if `--temp` or `--queue` is specified."
),
)
parser.add_argument(
"-M",
"--message",
type=str,
default=None,
help="Custom commit message to use when committing the experiment.",
)
13 changes: 10 additions & 3 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,17 @@ def reproduce_one(
self,
tmp_dir: bool = False,
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
**kwargs,
):
"""Reproduce and checkout a single (standalone) experiment."""
exp_queue: "BaseStashQueue" = (
self.tempdir_queue if tmp_dir else self.workspace_queue
)
self.queue_one(exp_queue, **kwargs)
results = self._reproduce_queue(exp_queue, copy_paths=copy_paths)
results = self._reproduce_queue(
exp_queue, copy_paths=copy_paths, message=message
)
exp_rev = first(results)
if exp_rev is not None:
self._log_reproduced(results, tmp_dir=tmp_dir)
Expand Down Expand Up @@ -349,7 +352,11 @@ def reset_checkpoints(self):

@unlocked_repo
def _reproduce_queue(
self, queue: "BaseStashQueue", copy_paths: Optional[List[str]] = None, **kwargs
self,
queue: "BaseStashQueue",
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
**kwargs,
) -> Dict[str, str]:
"""Reproduce queued experiments.
Expand All @@ -360,7 +367,7 @@ def _reproduce_queue(
dict mapping successfully reproduced experiment revs to their
results.
"""
exec_results = queue.reproduce(copy_paths=copy_paths)
exec_results = queue.reproduce(copy_paths=copy_paths, message=message)

results: Dict[str, str] = {}
for _, exp_result in exec_results.items():
Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/experiments/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def reproduce(
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
**kwargs,
) -> "ExecutorResult":
"""Run dvc repro and return the result.
Expand Down Expand Up @@ -570,6 +571,7 @@ def filter_pipeline(stages):
auto_push,
git_remote,
repro_force,
message=message,
)
info.result_hash = exp_hash
info.result_ref = ref
Expand All @@ -590,6 +592,7 @@ def _repro_commit(
auto_push,
git_remote,
repro_force,
message: Optional[str] = None,
) -> Tuple[Optional[str], Optional["ExpRefInfo"], bool]:
is_checkpoint = any(stage.is_checkpoint for stage in stages)
cls.commit(
Expand All @@ -598,6 +601,7 @@ def _repro_commit(
exp_name=info.name,
force=repro_force,
checkpoint=is_checkpoint,
message=message,
)
if auto_push:
cls._auto_push(dvc, dvc.scm, git_remote)
Expand Down Expand Up @@ -760,6 +764,7 @@ def commit(
exp_name: Optional[str] = None,
force: bool = False,
checkpoint: bool = False,
message: Optional[str] = None,
):
"""Commit stages as an experiment and return the commit SHA."""
rev = scm.get_rev()
Expand Down Expand Up @@ -789,7 +794,8 @@ def commit(
logger.debug("Commit to new experiment branch '%s'", branch)

scm.add([], update=True)
scm.commit(f"dvc: commit experiment {exp_hash}", no_verify=True)
message = message or f"dvc: commit experiment {exp_hash}"
scm.commit(message, no_verify=True)
new_rev = scm.get_rev()
if check_conflict:
new_rev = cls._raise_ref_conflict(scm, branch, new_rev, checkpoint)
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/experiments/executor/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def reproduce(
log_errors: bool = True,
log_level: Optional[int] = None,
copy_paths: Optional[List[str]] = None, # noqa: ARG003
message: Optional[str] = None, # noqa: ARG003
**kwargs,
) -> "ExecutorResult":
"""Reproduce an experiment on a remote machine over SSH.
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:

@abstractmethod
def reproduce(
self, copy_paths: Optional[List[str]] = None
self, copy_paths: Optional[List[str]] = None, message: Optional[str] = None
) -> Mapping[str, Mapping[str, str]]:
"""Reproduce queued experiments sequentially."""

Expand Down
12 changes: 9 additions & 3 deletions dvc/repo/experiments/queue/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,18 @@ def start_workers(self, count: int) -> int:
return started

def put(
self, *args, copy_paths: Optional[List[str]] = None, **kwargs
self,
*args,
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
**kwargs,
) -> QueueEntry:
"""Stash an experiment and add it to the queue."""
with get_exp_rwlock(self.repo, writes=["workspace", CELERY_STASH]):
entry = self._stash_exp(*args, **kwargs)
self.celery.signature(run_exp.s(entry.asdict(), copy_paths=copy_paths)).delay()
self.celery.signature(
run_exp.s(entry.asdict(), copy_paths=copy_paths, message=message)
).delay()
return entry

# NOTE: Queue consumption should not be done directly. Celery worker(s)
Expand Down Expand Up @@ -264,7 +270,7 @@ def iter_failed(self) -> Generator[QueueDoneResult, None, None]:
yield QueueDoneResult(queue_entry, exp_result)

def reproduce(
self, copy_paths: Optional[List[str]] = None
self, copy_paths: Optional[List[str]] = None, message: Optional[str] = None
) -> Mapping[str, Mapping[str, str]]:
raise NotImplementedError

Expand Down
8 changes: 7 additions & 1 deletion dvc/repo/experiments/queue/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def cleanup_exp(executor: TempDirExecutor, infofile: str) -> None:


@shared_task
def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None) -> None:
def run_exp(
entry_dict: Dict[str, Any],
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
) -> None:
"""Run a full experiment.
Experiment subtasks are executed inline as one atomic operation.
Expand All @@ -111,6 +115,8 @@ def run_exp(entry_dict: Dict[str, Any], copy_paths: Optional[List[str]] = None)
if copy_paths:
for path in copy_paths:
cmd.extend(["--copy-paths", path])
if message:
cmd.extend(["--message", message])
proc_dict = queue.proc.run_signature(cmd, name=entry.stash_rev)()
collect_exp.s(proc_dict, entry_dict)()
finally:
Expand Down
2 changes: 2 additions & 0 deletions dvc/repo/experiments/queue/tempdir.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _reproduce_entry(
entry: QueueEntry,
executor: "BaseExecutor",
copy_paths: Optional[List[str]] = None,
message: Optional[str] = None,
**kwargs,
) -> Dict[str, Dict[str, str]]:
from dvc.stage.monitor import CheckpointKilledError
Expand All @@ -113,6 +114,7 @@ def _reproduce_entry(
log_level=logger.getEffectiveLevel(),
log_errors=True,
copy_paths=copy_paths,
message=message,
)
if not exec_result.exp_hash:
raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'")
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/experiments/queue/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,16 @@ def iter_success(self) -> Generator["QueueDoneResult", None, None]:
raise NotImplementedError

def reproduce(
self, copy_paths: Optional[List[str]] = None
self, copy_paths: Optional[List[str]] = None, message: Optional[str] = None
) -> Dict[str, Dict[str, str]]:
results: Dict[str, Dict[str, str]] = defaultdict(dict)
try:
while True:
entry, executor = self.get()
results.update(
self._reproduce_entry(entry, executor, copy_paths=copy_paths)
self._reproduce_entry(
entry, executor, copy_paths=copy_paths, message=message
)
)
except ExpQueueEmptyError:
pass
Expand Down Expand Up @@ -117,6 +119,7 @@ def _reproduce_entry(
infofile=infofile,
log_level=logger.getEffectiveLevel(),
log_errors=not isinstance(executor, WorkspaceExecutor),
message=kwargs.get("message"),
)
if not exec_result.exp_hash:
raise DvcException(f"Failed to reproduce experiment '{rev[:7]}'")
Expand Down
3 changes: 3 additions & 0 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run( # noqa: C901, PLR0912
tmp_dir: bool = False,
queue: bool = False,
copy_paths: Optional[Iterable[str]] = None,
message: Optional[str] = None,
**kwargs,
) -> Dict[str, str]:
"""Reproduce the specified targets as an experiment.
Expand Down Expand Up @@ -74,6 +75,7 @@ def run( # noqa: C901, PLR0912
params=path_overrides,
tmp_dir=tmp_dir,
copy_paths=copy_paths,
message=message,
**kwargs,
)

Expand All @@ -96,6 +98,7 @@ def run( # noqa: C901, PLR0912
targets=targets,
params=sweep_overrides,
copy_paths=copy_paths,
message=message,
**kwargs,
)
if sweep_overrides:
Expand Down
16 changes: 16 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,19 @@ def test_mixed_git_dvc_out(tmp_dir, scm, dvc, exp_stage):
assert (tmp_dir / "dir" / "metrics.yaml").exists()
git_fs = scm.get_fs(exp)
assert not git_fs.exists("dir/metrics.yaml")


@pytest.mark.parametrize("tmp", [True, False])
def test_custom_commit_message(tmp_dir, scm, dvc, tmp):
stage = dvc.stage.add(
cmd="echo foo",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

exp = first(
dvc.experiments.run(
stage.addressing, tmp_dir=tmp, message="custom commit message"
)
)
assert scm.gitpython.repo.commit(exp).message == "custom commit message"
13 changes: 13 additions & 0 deletions tests/func/experiments/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,16 @@ def test_copy_paths_queue(tmp_dir, scm, dvc):
fs = scm.get_fs(exp)
assert not fs.exists("dir")
assert not fs.exists("file")


def test_custom_commit_message_queue(tmp_dir, scm, dvc):
stage = dvc.stage.add(
cmd="echo foo",
name="foo",
)
scm.add_commit(["dvc.yaml"], message="add dvc.yaml")

dvc.experiments.run(stage.addressing, queue=True, message="custom commit message")

exp = first(dvc.experiments.run(run_all=True))
assert scm.gitpython.repo.commit(exp).message == "custom commit message"
7 changes: 6 additions & 1 deletion tests/func/experiments/test_set_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ def test_hydra_sweep(
assert patched.call_count == len(expected)
for e in expected:
patched.assert_any_call(
mocker.ANY, params=e, reset=True, targets=None, copy_paths=None
mocker.ANY,
params=e,
reset=True,
targets=None,
copy_paths=None,
message=None,
)


Expand Down
1 change: 1 addition & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def test_experiments_run(dvc, scm, mocker):
"reset": False,
"machine": None,
"copy_paths": [],
"message": None,
}
default_arguments.update(repro_arguments)

Expand Down

0 comments on commit c74dc20

Please sign in to comment.