Skip to content

Commit

Permalink
checkpoints: set DVC_ROOT environment variable (iterative#4877)
Browse files Browse the repository at this point in the history
* checkpoints: set DVC_ROOT environment variable

* env: add DVC_ROOT, DVC_CHECKPOINT
  • Loading branch information
pmrowla authored Nov 12, 2020
1 parent 556143f commit 07676b2
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 3 deletions.
5 changes: 3 additions & 2 deletions dvc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,13 @@ def make_checkpoint():
import builtins
from time import sleep

from dvc.env import DVC_CHECKPOINT, DVC_ROOT
from dvc.stage.run import CHECKPOINT_SIGNAL_FILE

if os.getenv("DVC_CHECKPOINT") is None:
if os.getenv(DVC_CHECKPOINT) is None:
return

root_dir = Repo.find_root()
root_dir = os.getenv(DVC_ROOT, Repo.find_root())
signal_file = os.path.join(
root_dir, Repo.DVC_DIR, "tmp", CHECKPOINT_SIGNAL_FILE
)
Expand Down
2 changes: 2 additions & 0 deletions dvc/env.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
DVC_CHECKPOINT = "DVC_CHECKPOINT"
DVC_DAEMON = "DVC_DAEMON"
DVC_PAGER = "DVC_PAGER"
DVC_ROOT = "DVC_ROOT"
7 changes: 6 additions & 1 deletion dvc/stage/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
from contextlib import contextmanager

from dvc.env import DVC_CHECKPOINT, DVC_ROOT
from dvc.utils import fix_env

from .decorators import relock_repo, unlocked_repo
Expand Down Expand Up @@ -47,7 +48,7 @@ def cmd_run(stage, *args, checkpoint_func=None, **kwargs):
kwargs = {"cwd": stage.wdir, "env": fix_env(None), "close_fds": True}
if checkpoint_func:
# indicate that checkpoint cmd is being run inside DVC
kwargs["env"].update({"DVC_CHECKPOINT": "1"})
kwargs["env"].update(_checkpoint_env(stage))

if os.name == "nt":
kwargs["shell"] = True
Expand Down Expand Up @@ -118,6 +119,10 @@ def run_stage(stage, dry=False, force=False, checkpoint_func=None, **kwargs):
cmd_run(stage, checkpoint_func=checkpoint_func)


def _checkpoint_env(stage):
return {DVC_CHECKPOINT: "1", DVC_ROOT: stage.repo.root_dir}


@contextmanager
def checkpoint_monitor(stage, callback_func, proc, killed):
if not callback_func:
Expand Down
6 changes: 6 additions & 0 deletions tests/func/experiments/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from funcy import first

import dvc as dvc_module
from dvc.exceptions import DvcException
from dvc.repo.experiments import Experiments

Expand Down Expand Up @@ -62,10 +63,15 @@ def checkpoint_stage(tmp_dir, scm, dvc):


def test_new_checkpoint(tmp_dir, scm, dvc, checkpoint_stage, mocker):
from dvc.env import DVC_CHECKPOINT, DVC_ROOT

new_mock = mocker.spy(dvc.experiments, "new")
env_mock = mocker.spy(dvc_module.stage.run, "_checkpoint_env")
dvc.experiments.run(checkpoint_stage.addressing, params=["foo=2"])

new_mock.assert_called_once()
env_mock.assert_called_once()
assert set(env_mock.return_value.keys()) == {DVC_CHECKPOINT, DVC_ROOT}
assert (tmp_dir / "foo").read_text() == "5"
assert (
tmp_dir / ".dvc" / "experiments" / "metrics.yaml"
Expand Down

0 comments on commit 07676b2

Please sign in to comment.