Skip to content

Commit

Permalink
experiments: pop checkpoint resume from kwargs (iterative#4913)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjawhar authored Nov 20, 2020
1 parent 4cf2f81 commit 153c374
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
12 changes: 9 additions & 3 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,15 +437,21 @@ def reproduce_queued(self, **kwargs):

@scm_locked
def new(
self, *args, branch: Optional[str] = None, **kwargs,
self,
*args,
branch: Optional[str] = None,
checkpoint_resume: Optional[str] = None,
**kwargs,
):
"""Create a new experiment.
Experiment will be reproduced and checked out into the user's
workspace.
"""
if kwargs.get("checkpoint_resume", None) is not None:
return self._resume_checkpoint(*args, **kwargs)
if checkpoint_resume is not None:
return self._resume_checkpoint(
*args, **kwargs, checkpoint_resume=checkpoint_resume
)

if branch:
rev = self.scm.resolve_rev(branch)
Expand Down
31 changes: 15 additions & 16 deletions tests/func/experiments/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,17 @@ def test_new_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker):
).read_text().strip() == "foo: 2"


@pytest.mark.parametrize("last", [True, False])
def test_resume_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, last):
@pytest.mark.parametrize(
"checkpoint_resume", [Experiments.LAST_CHECKPOINT, "foo"]
)
def test_resume_checkpoint(
tmp_dir, scm, dvc, checkpoint_stage, checkpoint_resume
):
with pytest.raises(DvcException):
if last:
dvc.experiments.run(
checkpoint_stage.addressing,
checkpoint_resume=Experiments.LAST_CHECKPOINT,
)
else:
dvc.experiments.run(
checkpoint_stage.addressing, checkpoint_resume="foo"
)
dvc.experiments.run(
checkpoint_stage=checkpoint_stage.addressing,
checkpoint_resume=checkpoint_resume,
)

results = dvc.experiments.run(
checkpoint_stage.addressing, params=["foo=2"]
Expand All @@ -46,12 +45,12 @@ def test_resume_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, last):
checkpoint_stage.addressing, checkpoint_resume="abc1234",
)

if last:
exp_rev = Experiments.LAST_CHECKPOINT
else:
exp_rev = first(results)
if checkpoint_resume != Experiments.LAST_CHECKPOINT:
checkpoint_resume = first(results)

dvc.experiments.run(checkpoint_stage.addressing, checkpoint_resume=exp_rev)
dvc.experiments.run(
checkpoint_stage.addressing, checkpoint_resume=checkpoint_resume
)

assert (tmp_dir / "foo").read_text() == "10"
assert (
Expand Down

0 comments on commit 153c374

Please sign in to comment.