Skip to content

Commit

Permalink
exp: init: Use persist: true if checkpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
daavoo committed Jul 21, 2022
1 parent 0b07354 commit ed1d8e2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
17 changes: 12 additions & 5 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,25 @@ def init(
live_metrics = f"{live_path}.json" if live_path else None
live_plots = os.path.join(live_path, "scalars") if live_path else None

if type == "checkpoint":
outs_key = "checkpoints"
metrics_key = "metrics_persist_no_cache"
plots_key = "plots_persist_no_cache"
else:
outs_key = "outs"
metrics_key = "metrics_no_cache"
plots_key = "plots_no_cache"

stage = repo.stage.create(
name=name,
cmd=context["cmd"],
deps=compact([context.get("code"), context.get("data")]),
params=[{params: None}] if params else None,
metrics_no_cache=compact([context.get("metrics"), live_metrics]),
plots_no_cache=compact([context.get("plots"), live_plots]),
force=force,
**{
"checkpoints"
if type == "checkpoint"
else "outs": compact([models])
outs_key: compact([models]),
metrics_key: compact([context.get("metrics"), live_metrics]),
plots_key: compact([context.get("plots"), live_plots]),
},
)

Expand Down
8 changes: 6 additions & 2 deletions dvc/stage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ def fill_stage_outputs(stage, **kwargs):
keys = [
"outs_persist",
"outs_persist_no_cache",
"metrics_no_cache",
"metrics",
"plots_no_cache",
"metrics_persist",
"metrics_no_cache",
"metrics_persist_no_cache",
"plots",
"plots_persist",
"plots_no_cache",
"plots_persist_no_cache",
"outs_no_cache",
"outs",
"checkpoints",
Expand Down
4 changes: 2 additions & 2 deletions tests/func/experiments/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ def test_init_with_type_checkpoint_and_models_plots_provided(
"cmd": "cmd",
"deps": ["data", "src"],
"metrics": [
{"m": {"cache": False}},
{"m": {"cache": False, "persist": True}},
],
"outs": [{"models": {"checkpoint": True}}],
"params": [{"params.yaml": None}],
"plots": [
{"p": {"cache": False}},
{"p": {"cache": False, "persist": True}},
],
}
}
Expand Down

0 comments on commit ed1d8e2

Please sign in to comment.