Skip to content

Commit

Permalink
parametrization: support loading vars partially (iterative#4982)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Nov 27, 2020
1 parent 6d745eb commit 011bd18
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 24 deletions.
77 changes: 57 additions & 20 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections.abc import Mapping, Sequence
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING, List, Set
from typing import TYPE_CHECKING, Dict, List, Optional

from funcy import join
from funcy import join, lfilter

from dvc.dependency.param import ParamsDependency
from dvc.exceptions import DvcException
Expand All @@ -21,7 +21,6 @@
MergeError,
Meta,
Node,
ParamsFileNotFound,
SetError,
)

Expand All @@ -44,13 +43,25 @@
JOIN = "@"


class VarsAlreadyLoaded(DvcException):
pass


class ResolveError(DvcException):
pass


def _format_preamble(msg, path, spacing=" "):
return f"failed to parse {msg} in '{path}':{spacing}"


def format_and_raise(exc, msg, path):
spacing = "\n" if isinstance(exc, (ParseError, MergeError)) else " "
message = f"failed to parse {msg} in '{path}':{spacing}{str(exc)}"
spacing = (
"\n"
if isinstance(exc, (ParseError, MergeError, VarsAlreadyLoaded))
else " "
)
message = _format_preamble(msg, path, spacing) + str(exc)

# FIXME: cannot reraise because of how we log "cause" of the exception
# the error message is verbose, hence need control over the spacing
Expand All @@ -70,12 +81,12 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
self.wdir = wdir
self.repo = repo
self.tree = self.repo.tree
self.imported_files: Set[str] = set()
self.imported_files: Dict[str, Optional[List[str]]] = {}
self.relpath = relpath(self.wdir / "dvc.yaml")

to_import: PathInfo = wdir / DEFAULT_PARAMS_FILE
if self.tree.exists(to_import):
self.imported_files = {os.path.abspath(to_import)}
self.imported_files = {os.path.abspath(to_import): None}
self.global_ctx = Context.load_from(self.tree, to_import)
else:
self.global_ctx = Context()
Expand All @@ -89,28 +100,51 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
self.load_from_vars(
self.global_ctx, vars_, wdir, skip_imports=self.imported_files
)
except (ParamsFileNotFound, MergeError) as exc:
except (ContextError, VarsAlreadyLoaded) as exc:
format_and_raise(exc, "'vars'", self.relpath)

@staticmethod
def check_loaded(path, item, keys, skip_imports):
if not keys and isinstance(skip_imports[path], list):
raise VarsAlreadyLoaded(
f"cannot load '{item}' as it's partially loaded already"
)
elif keys and skip_imports[path] is None:
raise VarsAlreadyLoaded(
f"cannot partially load '{item}' as it's already loaded."
)
elif keys and isinstance(skip_imports[path], list):
if not set(keys).isdisjoint(set(skip_imports[path])):
raise VarsAlreadyLoaded(
f"cannot load '{item}' as it's partially loaded already"
)

def load_from_vars(
self,
context: "Context",
vars_: List,
wdir: PathInfo,
skip_imports: Set[str],
skip_imports: Dict[str, Optional[List[str]]],
stage_name: str = None,
):
stage_name = stage_name or ""
for index, item in enumerate(vars_):
assert isinstance(item, (str, dict))
if isinstance(item, str):
path_info = wdir / item
path, _, keys_str = item.partition(":")
keys = lfilter(bool, keys_str.split(","))

path_info = wdir / path
path = os.path.abspath(path_info)

if path in skip_imports:
continue
if not keys and skip_imports[path] is None:
# allow specifying complete filepath multiple times
continue
self.check_loaded(path, item, keys, skip_imports)

context.merge_from(self.tree, path_info)
skip_imports.add(path)
context.merge_from(self.tree, path_info, select_keys=keys)
skip_imports[path] = keys if keys else None
else:
joiner = "." if stage_name else ""
meta = Meta(source=f"{stage_name}{joiner}vars[{index}]")
Expand Down Expand Up @@ -151,13 +185,16 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict:

vars_ = definition.pop(VARS_KWD, [])
# FIXME: Should `vars` be templatized?
self.load_from_vars(
context,
vars_,
wdir,
skip_imports=deepcopy(self.imported_files),
stage_name=name,
)
try:
self.load_from_vars(
context,
vars_,
wdir,
skip_imports=deepcopy(self.imported_files),
stage_name=name,
)
except VarsAlreadyLoaded as exc:
format_and_raise(exc, f"'stages.{name}.vars'", self.relpath)

logger.trace( # type: ignore[attr-defined]
"Context during resolution of stage %s:\n%s", name, context
Expand Down
23 changes: 19 additions & 4 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,34 @@ def select(
return node.value if unwrap else node

@classmethod
def load_from(cls, tree, path: PathInfo) -> "Context":
def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context":
file = relpath(path)
if not tree.exists(path):
raise ParamsFileNotFound(f"'{file}' does not exist")

_, ext = os.path.splitext(file)
loader = LOADERS[ext]

data = loader(path, tree=tree)
select_keys = select_keys or []
if select_keys:
try:
data = {key: data[key] for key in select_keys}
except KeyError as exc:
key, *_ = exc.args
raise ContextError(
f"could not find '{key}' in '{file}'"
) from exc

meta = Meta(source=file, local=False)
return cls(loader(path, tree=tree), meta=meta)
return cls(data, meta=meta)

def merge_from(self, tree, path: PathInfo, overwrite=False):
self.merge_update(Context.load_from(tree, path), overwrite=overwrite)
def merge_from(
self, tree, path: PathInfo, overwrite=False, select_keys=None,
):
self.merge_update(
Context.load_from(tree, path, select_keys), overwrite=overwrite
)

@classmethod
def clone(cls, ctx: "Context") -> "Context":
Expand Down
22 changes: 22 additions & 0 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,25 @@ def test_vars_relpath_overwrite(tmp_dir, dvc):
}
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
resolver.resolve()


@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize(
"vars_",
[
["test_params.yaml:bar", "test_params.yaml:foo"],
["test_params.yaml:foo,bar"],
["test_params.yaml"],
["test_params.yaml", "test_params.yaml"],
],
)
def test_vars_load_partial(tmp_dir, dvc, local, vars_):
iterable = {"bar": "bar", "foo": "foo"}
dump_yaml(tmp_dir / "test_params.yaml", iterable)
d = {"stages": {"build": {"cmd": "echo ${bar}"}}}
if local:
d["stages"]["build"]["vars"] = vars_
else:
d["vars"] = vars_
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
resolver.resolve()
39 changes: 39 additions & 0 deletions tests/func/test_stage_resolver_error_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,42 @@ def test_interpolate_non_string(tmp_dir, repo):
"failed to parse 'stages.build.cmd' in 'dvc.yaml':\n"
"Cannot interpolate data of type 'dict'"
)


@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize(
"vars_",
[
["test_params.yaml", "test_params.yaml:sub1"],
["test_params.yaml:sub1", "test_params.yaml"],
["test_params.yaml:sub1", "test_params.yaml:sub1,sub2"],
],
)
def test_vars_already_loaded_message(tmp_dir, repo, local, vars_):
d = {"stages": {"build": {"cmd": "echo ${sub1} ${sub2}"}}}
dump_yaml("test_params.yaml", {"sub1": "sub1", "sub2": "sub2"})
if not local:
d["vars"] = vars_
else:
d["stages"]["build"]["vars"] = vars_

with pytest.raises(ResolveError) as exc_info:
resolver = DataResolver(repo, tmp_dir, d)
resolver.resolve()

assert "partially" in str(exc_info.value)


@pytest.mark.parametrize("local", [False, True])
def test_partial_vars_doesnot_exist(tmp_dir, repo, local):
d = {"stages": {"build": {"cmd": "echo ${sub1} ${sub2}"}}}
dump_yaml("test_params.yaml", {"sub1": "sub1", "sub2": "sub2"})
vars_ = ["test_params.yaml:sub3"]
if not local:
d["vars"] = vars_
else:
d["stages"]["build"]["vars"] = vars_

with pytest.raises(ResolveError):
resolver = DataResolver(repo, tmp_dir, d)
resolver.resolve()

0 comments on commit 011bd18

Please sign in to comment.