From 011bd18e9fa62d9d4adad87320beddb639c900a9 Mon Sep 17 00:00:00 2001 From: Saugat Pachhai Date: Fri, 27 Nov 2020 19:05:02 +0545 Subject: [PATCH] parametrization: support loading vars partially (#4982) --- dvc/parsing/__init__.py | 77 +++++++++++++++------ dvc/parsing/context.py | 23 ++++-- tests/func/test_stage_resolver.py | 22 ++++++ tests/func/test_stage_resolver_error_msg.py | 39 +++++++++++ 4 files changed, 137 insertions(+), 24 deletions(-) diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 2460c79cf6..4dbd414d5d 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -4,9 +4,9 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from itertools import starmap -from typing import TYPE_CHECKING, List, Set +from typing import TYPE_CHECKING, Dict, List, Optional -from funcy import join +from funcy import join, lfilter from dvc.dependency.param import ParamsDependency from dvc.exceptions import DvcException @@ -21,7 +21,6 @@ MergeError, Meta, Node, - ParamsFileNotFound, SetError, ) @@ -44,13 +43,25 @@ JOIN = "@" +class VarsAlreadyLoaded(DvcException): + pass + + class ResolveError(DvcException): pass +def _format_preamble(msg, path, spacing=" "): + return f"failed to parse {msg} in '{path}':{spacing}" + + def format_and_raise(exc, msg, path): - spacing = "\n" if isinstance(exc, (ParseError, MergeError)) else " " - message = f"failed to parse {msg} in '{path}':{spacing}{str(exc)}" + spacing = ( + "\n" + if isinstance(exc, (ParseError, MergeError, VarsAlreadyLoaded)) + else " " + ) + message = _format_preamble(msg, path, spacing) + str(exc) # FIXME: cannot reraise because of how we log "cause" of the exception # the error message is verbose, hence need control over the spacing @@ -70,12 +81,12 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): self.wdir = wdir self.repo = repo self.tree = self.repo.tree - self.imported_files: Set[str] = set() + self.imported_files: Dict[str, Optional[List[str]]] = {} self.relpath = relpath(self.wdir / "dvc.yaml") to_import: PathInfo = wdir / DEFAULT_PARAMS_FILE if self.tree.exists(to_import): - self.imported_files = {os.path.abspath(to_import)} + self.imported_files = {os.path.abspath(to_import): None} self.global_ctx = Context.load_from(self.tree, to_import) else: self.global_ctx = Context() @@ -89,28 +100,51 @@ def __init__(self, repo: "Repo", wdir: PathInfo, d: dict): self.load_from_vars( self.global_ctx, vars_, wdir, skip_imports=self.imported_files ) - except (ParamsFileNotFound, MergeError) as exc: + except (ContextError, VarsAlreadyLoaded) as exc: format_and_raise(exc, "'vars'", self.relpath) + @staticmethod + def check_loaded(path, item, keys, skip_imports): + if not keys and isinstance(skip_imports[path], list): + raise VarsAlreadyLoaded( + f"cannot load '{item}' as it's partially loaded already" + ) + elif keys and skip_imports[path] is None: + raise VarsAlreadyLoaded( + f"cannot partially load '{item}' as it's already loaded." + ) + elif keys and isinstance(skip_imports[path], list): + if not set(keys).isdisjoint(set(skip_imports[path])): + raise VarsAlreadyLoaded( + f"cannot load '{item}' as it's partially loaded already" + ) + def load_from_vars( self, context: "Context", vars_: List, wdir: PathInfo, - skip_imports: Set[str], + skip_imports: Dict[str, Optional[List[str]]], stage_name: str = None, ): stage_name = stage_name or "" for index, item in enumerate(vars_): assert isinstance(item, (str, dict)) if isinstance(item, str): - path_info = wdir / item + path, _, keys_str = item.partition(":") + keys = lfilter(bool, keys_str.split(",")) + + path_info = wdir / path path = os.path.abspath(path_info) + if path in skip_imports: - continue + if not keys and skip_imports[path] is None: + # allow specifying complete filepath multiple times + continue + self.check_loaded(path, item, keys, skip_imports) - context.merge_from(self.tree, path_info) - skip_imports.add(path) + context.merge_from(self.tree, path_info, select_keys=keys) + skip_imports[path] = keys if keys else None else: joiner = "." if stage_name else "" meta = Meta(source=f"{stage_name}{joiner}vars[{index}]") @@ -151,13 +185,16 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict: vars_ = definition.pop(VARS_KWD, []) # FIXME: Should `vars` be templatized? - self.load_from_vars( - context, - vars_, - wdir, - skip_imports=deepcopy(self.imported_files), - stage_name=name, - ) + try: + self.load_from_vars( + context, + vars_, + wdir, + skip_imports=deepcopy(self.imported_files), + stage_name=name, + ) + except VarsAlreadyLoaded as exc: + format_and_raise(exc, f"'stages.{name}.vars'", self.relpath) logger.trace( # type: ignore[attr-defined] "Context during resolution of stage %s:\n%s", name, context diff --git a/dvc/parsing/context.py b/dvc/parsing/context.py index fd42d3d94c..16fe5d08ca 100644 --- a/dvc/parsing/context.py +++ b/dvc/parsing/context.py @@ -310,7 +310,7 @@ def select( return node.value if unwrap else node @classmethod - def load_from(cls, tree, path: PathInfo) -> "Context": + def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context": file = relpath(path) if not tree.exists(path): raise ParamsFileNotFound(f"'{file}' does not exist") @@ -318,11 +318,26 @@ def load_from(cls, tree, path: PathInfo) -> "Context": _, ext = os.path.splitext(file) loader = LOADERS[ext] + data = loader(path, tree=tree) + select_keys = select_keys or [] + if select_keys: + try: + data = {key: data[key] for key in select_keys} + except KeyError as exc: + key, *_ = exc.args + raise ContextError( + f"could not find '{key}' in '{file}'" + ) from exc + meta = Meta(source=file, local=False) - return cls(loader(path, tree=tree), meta=meta) + return cls(data, meta=meta) - def merge_from(self, tree, path: PathInfo, overwrite=False): - self.merge_update(Context.load_from(tree, path), overwrite=overwrite) + def merge_from( + self, tree, path: PathInfo, overwrite=False, select_keys=None, + ): + self.merge_update( + Context.load_from(tree, path, select_keys), overwrite=overwrite + ) @classmethod def clone(cls, ctx: "Context") -> "Context": diff --git a/tests/func/test_stage_resolver.py b/tests/func/test_stage_resolver.py index c1c7b7c3ac..dee9d14731 100644 --- a/tests/func/test_stage_resolver.py +++ b/tests/func/test_stage_resolver.py @@ -502,3 +502,25 @@ def test_vars_relpath_overwrite(tmp_dir, dvc): } resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) resolver.resolve() + + +@pytest.mark.parametrize("local", [True, False]) +@pytest.mark.parametrize( + "vars_", + [ + ["test_params.yaml:bar", "test_params.yaml:foo"], + ["test_params.yaml:foo,bar"], + ["test_params.yaml"], + ["test_params.yaml", "test_params.yaml"], + ], +) +def test_vars_load_partial(tmp_dir, dvc, local, vars_): + iterable = {"bar": "bar", "foo": "foo"} + dump_yaml(tmp_dir / "test_params.yaml", iterable) + d = {"stages": {"build": {"cmd": "echo ${bar}"}}} + if local: + d["stages"]["build"]["vars"] = vars_ + else: + d["vars"] = vars_ + resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d) + resolver.resolve() diff --git a/tests/func/test_stage_resolver_error_msg.py b/tests/func/test_stage_resolver_error_msg.py index b92c70a769..f07977e13c 100644 --- a/tests/func/test_stage_resolver_error_msg.py +++ b/tests/func/test_stage_resolver_error_msg.py @@ -343,3 +343,42 @@ def test_interpolate_non_string(tmp_dir, repo): "failed to parse 'stages.build.cmd' in 'dvc.yaml':\n" "Cannot interpolate data of type 'dict'" ) + + +@pytest.mark.parametrize("local", [True, False]) +@pytest.mark.parametrize( + "vars_", + [ + ["test_params.yaml", "test_params.yaml:sub1"], + ["test_params.yaml:sub1", "test_params.yaml"], + ["test_params.yaml:sub1", "test_params.yaml:sub1,sub2"], + ], +) +def test_vars_already_loaded_message(tmp_dir, repo, local, vars_): + d = {"stages": {"build": {"cmd": "echo ${sub1} ${sub2}"}}} + dump_yaml("test_params.yaml", {"sub1": "sub1", "sub2": "sub2"}) + if not local: + d["vars"] = vars_ + else: + d["stages"]["build"]["vars"] = vars_ + + with pytest.raises(ResolveError) as exc_info: + resolver = DataResolver(repo, tmp_dir, d) + resolver.resolve() + + assert "partially" in str(exc_info.value) + + +@pytest.mark.parametrize("local", [False, True]) +def test_partial_vars_doesnot_exist(tmp_dir, repo, local): + d = {"stages": {"build": {"cmd": "echo ${sub1} ${sub2}"}}} + dump_yaml("test_params.yaml", {"sub1": "sub1", "sub2": "sub2"}) + vars_ = ["test_params.yaml:sub3"] + if not local: + d["vars"] = vars_ + else: + d["stages"]["build"]["vars"] = vars_ + + with pytest.raises(ResolveError): + resolver = DataResolver(repo, tmp_dir, d) + resolver.resolve()