Skip to content

Commit

Permalink
experiments: abstract executor tmpdir population using tree upload (iterative#4432)
Browse files Browse the repository at this point in the history

* executor: abstract tmpdir population using tree upload

* experiments: unmark flaky test

* fix wrong tree download

* LocalTree: use os.replace instead of os.rename in _upload
  • Loading branch information
pmrowla authored Aug 21, 2020
1 parent d89c16d commit dac17e2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 45 deletions.
73 changes: 54 additions & 19 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

from funcy import cached_property, first

from dvc.exceptions import DownloadError, DvcException
from dvc.exceptions import DownloadError, DvcException, UploadError
from dvc.path_info import PathInfo
from dvc.progress import Tqdm
from dvc.repo.experiments.executor import ExperimentExecutor, LocalExecutor
from dvc.scm.git import Git
from dvc.stage.serialize import to_lockfile
from dvc.tree.repo import RepoTree
from dvc.utils import dict_sha256, env2bool, relpath
from dvc.utils.fs import remove

Expand Down Expand Up @@ -52,6 +53,8 @@ def __init__(self, rev):

class BaselineMismatchError(DvcException):
def __init__(self, rev, expected):
if hasattr(rev, "hexsha"):
rev = rev.hexsha
rev_str = f"{rev[:7]}" if rev is not None else "dangling commit"
super().__init__(
f"Experiment derived from '{rev_str}', expected '{expected[:7]}'."
Expand Down Expand Up @@ -236,6 +239,7 @@ def _commit(self, exp_hash, check_exists=True, branch=True):
"""Commit stages as an experiment and return the commit SHA."""
if not self.scm.is_dirty():
raise UnchangedExperimentError(self.scm.get_rev())

rev = self.scm.get_rev()
exp_name = f"{rev[:7]}-{exp_hash}"
if branch:
Expand Down Expand Up @@ -326,19 +330,26 @@ def reproduce(
for rev in revs
}

# setup executors
logger.debug(
"Reproducing experiment revs '%s'",
", ".join((rev[:7] for rev in to_run)),
)

# setup executors - unstash experiment, generate executor, upload
# contents of (unstashed) exp workspace to the executor tree
executors = {}
for rev, baseline_rev in to_run.items():
tree = self.scm.get_tree(rev)
repro_args, repro_kwargs = self._unpack_args(tree)
self._scm_checkout(baseline_rev)
self.scm.repo.git.stash("apply", rev)
repro_args, repro_kwargs = self._unpack_args()
executor = LocalExecutor(
tree,
baseline_rev,
repro_args=repro_args,
repro_kwargs=repro_kwargs,
dvc_dir=self.dvc_dir,
cache_dir=self.repo.cache.local.cache_dir,
)
self._collect_input(executor)
executors[rev] = executor

exec_results = self._reproduce(executors, **kwargs)
Expand Down Expand Up @@ -421,42 +432,66 @@ def _reproduce(self, executors: dict, jobs: Optional[int] = 1) -> dict:

return result

def _collect_input(self, executor: ExperimentExecutor):
"""Copy (upload) input from the experiments workspace to the executor
tree.
"""
logger.debug("Collecting input for '%s'", executor.tmp_dir)
repo_tree = RepoTree(self.exp_dvc)
self._process(
executor.tree,
self.exp_dvc.tree,
executor.collect_files(self.exp_dvc.tree, repo_tree),
)

def _collect_output(self, executor: ExperimentExecutor):
"""Copy (download) output from the executor tree into experiments
workspace.
"""
from dvc.cache.local import _log_exceptions

logger.debug("Collecting output from '%s'", executor.tmp_dir)
dest_tree = self.exp_dvc.tree
src_tree = executor.tree
self._process(
self.exp_dvc.tree,
executor.tree,
executor.collect_output(),
download=True,
)

@staticmethod
def _process(dest_tree, src_tree, collected_files, download=False):
from dvc.cache.local import _log_exceptions

from_infos = []
to_infos = []
names = []
for from_info in executor.collect_output():
for from_info in collected_files:
from_infos.append(from_info)
fname = from_info.relative_to(src_tree.path_info)
names.append(str(fname))
to_info = dest_tree.path_info / fname
to_infos.append(dest_tree.path_info / fname)
logger.debug(f"from '{from_info}' to '{to_info}'")
total = len(from_infos)

func = partial(
_log_exceptions(dest_tree.download, "download"),
dir_mode=dest_tree.dir_mode,
file_mode=dest_tree.file_mode,
)
with Tqdm(total=total, unit="file", desc="Downloading") as pbar:
if download:
func = partial(
_log_exceptions(src_tree.download, "download"),
dir_mode=dest_tree.dir_mode,
file_mode=dest_tree.file_mode,
)
desc = "Downloading"
else:
func = partial(_log_exceptions(dest_tree.upload, "upload"))
desc = "Uploading"

with Tqdm(total=total, unit="file", desc=desc) as pbar:
func = pbar.wrap_fn(func)
# TODO: parallelize this, currently --jobs for repro applies to
# number of repro executors not download threads
with ThreadPoolExecutor(max_workers=1) as dl_executor:
fails = sum(dl_executor.map(func, from_infos, to_infos, names))

if fails:
raise DownloadError(fails)
if download:
raise DownloadError(fails)
raise UploadError(fails)

@scm_locked
def checkout_exp(self, rev):
Expand Down
44 changes: 22 additions & 22 deletions dvc/repo/experiments/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from dvc.tree.base import BaseTree
from dvc.tree.local import LocalTree
from dvc.tree.repo import RepoTree
from dvc.utils import relpath
from dvc.utils.fs import copy_fobj_to_file, makedirs

logger = logging.getLogger(__name__)

Expand All @@ -21,16 +19,14 @@ class ExperimentExecutor:
"""Base class for executing experiments in parallel.
Args:
src_tree: source tree for this experiment.
baseline_rev: baseline revision that this experiment is derived from.
Optional keyword args:
repro_args: Args to be passed into reproduce.
repro_kwargs: Keyword args to be passed into reproduce.
"""

def __init__(self, src_tree: BaseTree, baseline_rev: str, **kwargs):
self.src_tree = src_tree
def __init__(self, baseline_rev: str, **kwargs):
self.baseline_rev = baseline_rev
self.repro_args = kwargs.pop("repro_args", [])
self.repro_kwargs = kwargs.pop("repro_kwargs", {})
Expand Down Expand Up @@ -73,27 +69,29 @@ def unpack_repro_args(path, tree=None):


class LocalExecutor(ExperimentExecutor):
"""Local machine exepriment executor."""
"""Local machine experiment executor."""

def __init__(self, baseline_rev: str, **kwargs):
from dvc.repo import Repo

def __init__(self, src_tree: BaseTree, baseline_rev: str, **kwargs):
dvc_dir = kwargs.pop("dvc_dir")
cache_dir = kwargs.pop("cache_dir")
super().__init__(src_tree, baseline_rev, **kwargs)
super().__init__(baseline_rev, **kwargs)
self.tmp_dir = TemporaryDirectory()
logger.debug("Init local executor in dir '%s'.", self.tmp_dir)

# init empty DVC repo (will be overwritten when input is uploaded)
Repo.init(root_dir=self.tmp_dir.name, no_scm=True)
logger.debug(
"Init local executor in dir '%s' with baseline '%s'.",
self.tmp_dir,
baseline_rev[:7],
)
self.dvc_dir = os.path.join(self.tmp_dir.name, dvc_dir)
try:
for fname in src_tree.walk_files(src_tree.tree_root):
dest = self.path_info / relpath(fname, src_tree.tree_root)
if not os.path.exists(dest.parent):
makedirs(dest.parent)
with src_tree.open(fname, "rb") as fobj:
copy_fobj_to_file(fobj, dest)
except Exception:
self.tmp_dir.cleanup()
raise
self._config(cache_dir)
self._tree = LocalTree(self.dvc, {"url": self.dvc.root_dir})
# override default CACHE_MODE since files must be writable in order
# to run repro
self._tree.CACHE_MODE = 0o644

def _config(self, cache_dir):
local_config = os.path.join(self.dvc_dir, "config.local")
Expand Down Expand Up @@ -151,11 +149,13 @@ def filter_pipeline(stage):

def collect_output(self) -> Iterable["PathInfo"]:
repo_tree = RepoTree(self.dvc)
yield from self.collect_files(self.tree, repo_tree)

@staticmethod
def collect_files(tree: BaseTree, repo_tree: RepoTree):
for fname in repo_tree.walk_files(repo_tree.root_dir, dvcfiles=True):
if not repo_tree.isdvc(fname):
yield self.tree.path_info / fname.relative_to(
repo_tree.root_dir
)
yield tree.path_info / fname.relative_to(repo_tree.root_dir)

def cleanup(self):
logger.debug("Removing tmpdir '%s'", self.tmp_dir)
Expand Down
4 changes: 0 additions & 4 deletions tests/func/experiments/test_experiments.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import pytest

from tests.func.test_repro_multistage import COPY_SCRIPT


# https://github.com/iterative/dvc/issues/4401
@pytest.mark.flaky(max_runs=3, min_passes=1)
def test_new_simple(tmp_dir, scm, dvc, mocker):
tmp_dir.gen("copy.py", COPY_SCRIPT)
tmp_dir.gen("params.yaml", "foo: 1")
Expand Down

0 comments on commit dac17e2

Please sign in to comment.