Skip to content

Commit

Permalink
experiments: handle exp run from within subdirs/subrepos (iterative…
Browse files Browse the repository at this point in the history
…#5093)

* experiments: add test cases for stage in subdir & subrepo

* git: fix relpath handling in dulwich add()

* repro: move exp file git staging from stage.run into repro

* experiments: support running experiments from outside DVC root

* experiments: setup logger inside multiprocessing (executor) context

* git: fix dulwich add in submodules
  • Loading branch information
pmrowla authored Dec 15, 2020
1 parent 84a3a59 commit 6b9d842
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 14 deletions.
6 changes: 5 additions & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _update_params(self, params: dict):
logger.debug("Using experiment params '%s'", params)

for params_fname in params:
path = PathInfo(self.repo.root_dir) / params_fname
path = PathInfo(params_fname)
suffix = path.suffix.lower()
modify_data = MODIFIERS[suffix]
with modify_data(path, tree=self.repo.tree) as data:
Expand Down Expand Up @@ -496,6 +496,8 @@ def _reproduce(

manager = Manager()
pid_q = manager.Queue()

rel_cwd = relpath(os.getcwd(), self.repo.root_dir)
with ProcessPoolExecutor(max_workers=jobs) as workers:
futures = {}
for rev, executor in executors.items():
Expand All @@ -505,6 +507,8 @@ def _reproduce(
pid_q,
rev,
name=executor.name,
rel_cwd=rel_cwd,
log_level=logger.getEffectiveLevel(),
)
futures[future] = (rev, executor)

Expand Down
24 changes: 20 additions & 4 deletions dvc/repo/experiments/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ def reproduce(
dvc_dir: str,
queue: "Queue",
rev: str,
cwd: Optional[str] = None,
rel_cwd: Optional[str] = None,
name: Optional[str] = None,
log_level: Optional[int] = None,
) -> Tuple[Optional[str], bool]:
"""Run dvc repro and return the result.
Expand All @@ -211,6 +212,7 @@ def reproduce(
unchanged = []

queue.put((rev, os.getpid()))
cls._set_log_level(log_level)

def filter_pipeline(stages):
unchanged.extend(
Expand All @@ -223,9 +225,11 @@ def filter_pipeline(stages):
try:
dvc = Repo(dvc_dir)
old_cwd = os.getcwd()
new_cwd = cwd if cwd else dvc.root_dir
os.chdir(new_cwd)
logger.debug("Running repro in '%s'", cwd)
if rel_cwd:
os.chdir(os.path.join(dvc.root_dir, rel_cwd))
else:
os.chdir(dvc.root_dir)
logger.debug("Running repro in '%s'", os.getcwd())

args_path = os.path.join(
dvc.tmp_dir, BaseExecutor.PACKED_ARGS_FILE
Expand Down Expand Up @@ -321,6 +325,18 @@ def commit(cls, scm: "Git", exp_hash: str, exp_name: Optional[str] = None):
scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
return new_rev

@staticmethod
def _set_log_level(level):
from dvc.logger import disable_other_loggers

# When executor.reproduce is run in a multiprocessing child process,
# dvc.main will not be called for that child process so we need to
# setup logging ourselves
dvc_logger = logging.getLogger("dvc")
disable_other_loggers()
if level is not None:
dvc_logger.setLevel(level)


class LocalExecutor(BaseExecutor):
"""Local machine experiment executor."""
Expand Down
12 changes: 12 additions & 0 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import typing
from functools import partial

Expand All @@ -18,6 +19,7 @@
def _reproduce_stage(stage, **kwargs):
def _run_callback(repro_callback):
_dump_stage(stage)
_track_stage(stage)
repro_callback([stage])

checkpoint_func = kwargs.pop("checkpoint_func", None)
Expand All @@ -42,7 +44,10 @@ def _run_callback(repro_callback):
return []

if not kwargs.get("dry", False):
track = checkpoint_func is not None
_dump_stage(stage)
if track:
_track_stage(stage)

return [stage]

Expand All @@ -54,6 +59,13 @@ def _dump_stage(stage):
dvcfile.dump(stage, update_pipeline=False)


def _track_stage(stage):
for out in stage.outs:
if not out.use_scm_ignore and out.is_in_repo:
stage.repo.scm.track_file(os.fspath(out.path_info))
stage.repo.scm.track_changed_files()


def _get_active_graph(G):
import networkx as nx

Expand Down
36 changes: 31 additions & 5 deletions dvc/scm/git/backend/dulwich.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import os
import stat
from io import BytesIO, StringIO
from typing import Callable, Iterable, Optional, Tuple
from typing import Callable, Dict, Iterable, Optional, Tuple

from dvc.path_info import PathInfo
from dvc.scm.base import SCMError
from dvc.utils import relpath

Expand Down Expand Up @@ -67,8 +68,24 @@ def __init__( # pylint:disable=W0231
except NotGitRepository as exc:
raise SCMError(f"{root_dir} is not a git repository") from exc

self._submodules: Dict[str, "PathInfo"] = self._find_submodules()
self._stashes: dict = {}

def _find_submodules(self) -> Dict[str, "PathInfo"]:
"""Return dict mapping submodule names to submodule paths.
Submodule paths will be relative to Git repo root.
"""
from dulwich.config import ConfigFile, parse_submodules

submodules: Dict[str, "PathInfo"] = {}
config_path = os.path.join(self.root_dir, ".gitmodules")
if os.path.isfile(config_path):
config = ConfigFile.from_path(config_path)
for path, _url, section in parse_submodules(config):
submodules[os.fsdecode(section)] = PathInfo(os.fsdecode(path))
return submodules

def close(self):
self.repo.close()

Expand Down Expand Up @@ -101,8 +118,19 @@ def add(self, paths: Iterable[str]):

files = []
for path in paths:
if not os.path.isabs(path):
path = os.path.join(self.root_dir, path)
if not os.path.isabs(path) and self._submodules:
# NOTE: If path is inside a submodule, Dulwich expects the
# staged paths to be relative to the submodule root (not the
# parent git repo root). We append path to root_dir here so
# that the result of relpath(path, root_dir) is actually the
# path relative to the submodule root.
path_info = PathInfo(path).relative_to(self.root_dir)
for sm_path in self._submodules.values():
if path_info.isin(sm_path):
path = os.path.join(
self.root_dir, path_info.relative_to(sm_path)
)
break
if os.path.isdir(path):
files.extend(walk_files(path))
else:
Expand Down Expand Up @@ -138,8 +166,6 @@ def untracked_files(self) -> Iterable[str]:
raise NotImplementedError

def is_tracked(self, path: str) -> bool:
from dvc.path_info import PathInfo

rel = PathInfo(path).relative_to(self.root_dir).as_posix().encode()
rel_dir = rel + b"/"
for path in self.repo.open_index():
Expand Down
4 changes: 0 additions & 4 deletions dvc/stage/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,5 @@ def _kill_nt(proc):
def _run_callback(stage, callback_func):
stage.save(allow_missing=True)
stage.commit(allow_missing=True)
for out in stage.outs:
if not out.use_scm_ignore and out.is_in_repo:
stage.repo.scm.track_file(os.fspath(out.path_info))
stage.repo.scm.track_changed_files()
logger.debug("Running checkpoint callback for stage '%s'", stage)
callback_func()
77 changes: 77 additions & 0 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from funcy import first

from dvc.dvcfile import PIPELINE_FILE
from dvc.repo.experiments.utils import exp_refs_by_rev
from dvc.utils.serialize import PythonFileCorruptedError
from tests.func.test_repro_multistage import COPY_SCRIPT
Expand Down Expand Up @@ -436,3 +437,79 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
baseline_a: {ref_info_a.name, ref_info_b.name},
baseline_c: {ref_info_c.name},
}


def test_subdir(tmp_dir, scm, dvc):
subdir = tmp_dir / "dir"
subdir.gen("copy.py", COPY_SCRIPT)
subdir.gen("params.yaml", "foo: 1")

with subdir.chdir():
dvc.run(
cmd="python copy.py params.yaml metrics.yaml",
metrics_no_cache=["metrics.yaml"],
params=["foo"],
name="copy-file",
no_exec=True,
)
scm.add(
[subdir / "dvc.yaml", subdir / "copy.py", subdir / "params.yaml"]
)
scm.commit("init")

results = dvc.experiments.run(PIPELINE_FILE, params=["foo=2"])
assert results

exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))

tree = scm.get_tree(exp)
for fname in ["metrics.yaml", "dvc.lock"]:
assert tree.exists(subdir / fname)
with tree.open(subdir / "metrics.yaml") as fobj:
assert fobj.read().strip() == "foo: 2"

assert dvc.experiments.get_exact_name(exp) == ref_info.name
assert scm.resolve_rev(ref_info.name) == exp


def test_subrepo(tmp_dir, scm):
from tests.unit.tree.test_repo import make_subrepo

subrepo = tmp_dir / "dir" / "repo"
make_subrepo(subrepo, scm)

subrepo.gen("copy.py", COPY_SCRIPT)
subrepo.gen("params.yaml", "foo: 1")

with subrepo.chdir():
subrepo.dvc.run(
cmd="python copy.py params.yaml metrics.yaml",
metrics_no_cache=["metrics.yaml"],
params=["foo"],
name="copy-file",
no_exec=True,
)
scm.add(
[
subrepo / "dvc.yaml",
subrepo / "copy.py",
subrepo / "params.yaml",
]
)
scm.commit("init")

results = subrepo.dvc.experiments.run(PIPELINE_FILE, params=["foo=2"])
assert results

exp = first(results)
ref_info = first(exp_refs_by_rev(scm, exp))

tree = scm.get_tree(exp)
for fname in ["metrics.yaml", "dvc.lock"]:
assert tree.exists(subrepo / fname)
with tree.open(subrepo / "metrics.yaml") as fobj:
assert fobj.read().strip() == "foo: 2"

assert subrepo.dvc.experiments.get_exact_name(exp) == ref_info.name
assert scm.resolve_rev(ref_info.name) == exp

0 comments on commit 6b9d842

Please sign in to comment.