Skip to content

Commit

Permalink
Change use and vars to just vars (iterative#4880)
Browse files Browse the repository at this point in the history
The schema of `vars` is changed to accommodate both the use of
`imports` and the setting of variables locally.

Now:
```
vars:
  - params.yaml
  - foo: foo
    bar: bar
```

The same distinction also applies to stages.
  • Loading branch information
skshetry authored Nov 20, 2020
1 parent 92b17e3 commit 40c3bbd
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 57 deletions.
79 changes: 38 additions & 41 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
import os
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from itertools import starmap
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Set

from funcy import first, join
from funcy import join

from dvc.dependency.param import ParamsDependency
from dvc.path_info import PathInfo
Expand All @@ -19,7 +18,6 @@
logger = logging.getLogger(__name__)

STAGES_KWD = "stages"
USE_KWD = "use"
VARS_KWD = "vars"
WDIR_KWD = "wdir"
DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
Expand All @@ -33,24 +31,46 @@

class DataResolver:
def __init__(self, repo: "Repo", wdir: PathInfo, d: dict):
to_import: PathInfo = wdir / d.get(USE_KWD, DEFAULT_PARAMS_FILE)
vars_ = d.get(VARS_KWD, {})
vars_ctx = Context(vars_)
if os.path.exists(to_import):
self.global_ctx_source = to_import

self.data: dict = d
self.wdir = wdir
self.repo = repo
self.imported_files: Set[PathInfo] = set()

to_import: PathInfo = wdir / DEFAULT_PARAMS_FILE
if repo.tree.exists(to_import):
self.imported_files = {to_import}
self.global_ctx = Context.load_from(repo.tree, str(to_import))
else:
self.global_ctx = Context()
self.global_ctx_source = None
logger.debug(
"%s does not exist, it won't be used in parametrization",
to_import,
)

self.global_ctx.merge_update(vars_ctx)
self.data: dict = d
self.wdir = wdir
self.repo = repo
vars_ = d.get(VARS_KWD, [])
self.load_from_vars(
self.global_ctx, vars_, wdir, skip_imports=self.imported_files
)

def load_from_vars(
    self,
    context: "Context",
    vars_: List,
    wdir: PathInfo,
    skip_imports: Set[PathInfo],
):
    """Merge every entry of a ``vars`` list into ``context``.

    Each entry is either a string or a dict:

    * a dict is wrapped in a ``Context`` and merged via
      ``context.merge_update``;
    * a string is joined onto ``wdir`` to form a path. If that path is
      already in ``skip_imports`` it is ignored; otherwise it is merged
      via ``context.merge_from`` using ``self.repo.tree`` and then
      recorded in ``skip_imports``.

    Args:
        context: the context that receives the merged values.
        vars_: the list of entries to process, in order.
        wdir: directory against which string entries are resolved.
        skip_imports: paths that must not be loaded again. This set is
            mutated in place: every newly loaded path is added to it.
    """
    for entry in vars_:
        assert isinstance(entry, (str, dict))
        if isinstance(entry, dict):
            # Inline variables: merge the mapping directly.
            context.merge_update(Context(entry))
            continue

        # String entry: a file to import, relative to ``wdir``.
        target = wdir / entry
        if target not in skip_imports:
            context.merge_from(self.repo.tree, str(target))
            skip_imports.add(target)

def _resolve_entry(self, name: str, definition):
context = Context.clone(self.global_ctx)
Expand All @@ -77,33 +97,10 @@ def _resolve_stage(self, context: Context, name: str, definition) -> dict:
"Stage %s has different wdir than dvc.yaml file", name
)

contexts = []
params_yaml_file = wdir / DEFAULT_PARAMS_FILE
if self.global_ctx_source != params_yaml_file:
if os.path.exists(params_yaml_file):
contexts.append(
Context.load_from(self.repo.tree, str(params_yaml_file))
)
else:
logger.debug(
"%s does not exist for stage %s", params_yaml_file, name
)

params_deps = definition.get(PARAMS_KWD, [])
params_files = {
wdir / first(item)
for item in params_deps
if item and isinstance(item, dict)
}
for params_file in params_files - {
self.global_ctx_source,
params_yaml_file,
}:
contexts.append(
Context.load_from(self.repo.tree, str(params_file))
)

context.merge_update(*contexts)
vars_ = definition.pop(VARS_KWD, [])
self.load_from_vars(
context, vars_, wdir, skip_imports=deepcopy(self.imported_files)
)

logger.trace( # pytype: disable=attribute-error
"Context during resolution of stage %s:\n%s", name, context
Expand Down
8 changes: 8 additions & 0 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ def load_from(cls, tree, file: str) -> "Context":
meta = Meta(source=file)
return cls(loader(file, tree=tree), meta=meta)

def merge_from(self, tree, path, overwrite=False):
    """Load the file at ``path`` from ``tree`` and merge it into this context.

    Args:
        tree: object queried with ``tree.exists(path)`` and handed to
            ``Context.load_from`` to read the file.
        path: location of the file to load; it is converted with
            ``str()`` before loading.
        overwrite: forwarded unchanged to ``merge_update``.

    Raises:
        FileNotFoundError: if ``tree.exists(path)`` is falsy.
    """
    # Fail before attempting to load a file that is not there.
    if not tree.exists(path):
        raise FileNotFoundError(path)

    loaded = Context.load_from(tree, str(path))
    self.merge_update(loaded, overwrite=overwrite)

@classmethod
def clone(cls, ctx: "Context") -> "Context":
"""Clones given context."""
Expand Down
12 changes: 6 additions & 6 deletions dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dvc import dependency, output
from dvc.hash_info import HashInfo
from dvc.output import CHECKSUMS_SCHEMA, BaseOutput
from dvc.parsing import FOREACH_KWD, IN_KWD, SET_KWD, USE_KWD, VARS_KWD
from dvc.parsing import FOREACH_KWD, IN_KWD, SET_KWD, VARS_KWD
from dvc.stage.params import StageParams

STAGES = "stages"
Expand Down Expand Up @@ -60,14 +60,15 @@

PARAM_PSTAGE_NON_DEFAULT_SCHEMA = {str: [str]}

VARS_SCHEMA = [str, dict]

STAGE_DEFINITION = {
StageParams.PARAM_CMD: str,
Optional(SET_KWD): dict,
Optional(StageParams.PARAM_WDIR): str,
Optional(StageParams.PARAM_DEPS): [str],
Optional(StageParams.PARAM_PARAMS): [
Any(str, PARAM_PSTAGE_NON_DEFAULT_SCHEMA)
],
Optional(StageParams.PARAM_PARAMS): [Any(str, dict)],
Optional(VARS_KWD): VARS_SCHEMA,
Optional(StageParams.PARAM_FROZEN): bool,
Optional(StageParams.PARAM_META): object,
Optional(StageParams.PARAM_DESC): str,
Expand All @@ -87,8 +88,7 @@
SINGLE_PIPELINE_STAGE_SCHEMA = {str: Any(STAGE_DEFINITION, FOREACH_IN)}
MULTI_STAGE_SCHEMA = {
STAGES: SINGLE_PIPELINE_STAGE_SCHEMA,
USE_KWD: str,
VARS_KWD: dict,
VARS_KWD: VARS_SCHEMA,
}

COMPILED_SINGLE_STAGE_SCHEMA = Schema(SINGLE_STAGE_SCHEMA)
Expand Down
23 changes: 13 additions & 10 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_simple(tmp_dir, dvc):

def test_vars(tmp_dir, dvc):
d = deepcopy(TEMPLATED_DVC_YAML_DATA)
d["vars"] = CONTEXT_DATA
d["vars"] = [CONTEXT_DATA]
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
resolved_data = deepcopy(RESOLVED_DVC_YAML_DATA)

Expand All @@ -95,14 +95,14 @@ def test_no_params_yaml_and_vars(tmp_dir, dvc):
resolver.resolve()


def test_use(tmp_dir, dvc):
def test_vars_import(tmp_dir, dvc):
"""
Test that different file can be loaded using `use`
instead of default params.yaml.
"""
dump_yaml(tmp_dir / "params2.yaml", CONTEXT_DATA)
d = deepcopy(TEMPLATED_DVC_YAML_DATA)
d["use"] = "params2.yaml"
d["vars"] = ["params2.yaml"]
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)

resolved_data = deepcopy(RESOLVED_DVC_YAML_DATA)
Expand All @@ -119,8 +119,7 @@ def test_vars_and_params_import(tmp_dir, dvc):
whilst tracking the "used" variables from params.
"""
d = {
"use": DEFAULT_PARAMS_FILE,
"vars": {"dict": {"foo": "foobar"}},
"vars": [DEFAULT_PARAMS_FILE, {"dict": {"foo": "foobar"}}],
"stages": {"stage1": {"cmd": "echo ${dict.foo} ${dict.bar}"}},
}
dump_yaml(tmp_dir / DEFAULT_PARAMS_FILE, {"dict": {"bar": "bar"}})
Expand All @@ -139,12 +138,12 @@ def test_vars_and_params_import(tmp_dir, dvc):
def test_with_params_section(tmp_dir, dvc):
"""Test that params section is also loaded for interpolation"""
d = {
"use": "params.yaml",
"vars": {"dict": {"foo": "foo"}},
"vars": [DEFAULT_PARAMS_FILE, {"dict": {"foo": "foo"}}],
"stages": {
"stage1": {
"cmd": "echo ${dict.foo} ${dict.bar} ${dict.foobar}",
"params": [{"params.json": ["value1"]}],
"vars": ["params.json"],
},
},
}
Expand Down Expand Up @@ -177,6 +176,7 @@ def test_stage_with_wdir(tmp_dir, dvc):
"cmd": "echo ${dict.foo} ${dict.bar}",
"params": ["value1"],
"wdir": "data",
"vars": [DEFAULT_PARAMS_FILE],
},
},
}
Expand Down Expand Up @@ -217,6 +217,7 @@ def test_with_templated_wdir(tmp_dir, dvc):
"cmd": "echo ${dict.foo} ${dict.bar}",
"params": ["value1"],
"wdir": "${dict.ws}",
"vars": [DEFAULT_PARAMS_FILE],
},
},
}
Expand Down Expand Up @@ -295,7 +296,7 @@ def test_foreach_loop_dict(tmp_dir, dvc):

def test_foreach_loop_templatized(tmp_dir, dvc):
params = {"models": {"us": {"thresh": 10}}}
vars_ = {"models": {"gb": {"thresh": 15}}}
vars_ = [{"models": {"gb": {"thresh": 15}}}]
dump_yaml(tmp_dir / DEFAULT_PARAMS_FILE, params)
d = {
"vars": vars_,
Expand Down Expand Up @@ -401,7 +402,7 @@ def test_set_with_foreach_and_on_stage_definition(tmp_dir, dvc):
dump_json(tmp_dir / "params.json", iterable)

d = {
"use": "params.json",
"vars": ["params.json"],
"stages": {
"build": {
"set": {"data": "${models}"},
Expand Down Expand Up @@ -436,11 +437,12 @@ def test_resolve_local_tries_to_load_globally_used_files(tmp_dir, dvc):
dump_json(tmp_dir / "params.json", iterable)

d = {
"use": "params.json",
"vars": ["params.json"],
"stages": {
"build": {
"cmd": "command --value ${bar}",
"params": [{"params.json": ["foo"]}],
"vars": ["params.json"],
},
},
}
Expand All @@ -467,6 +469,7 @@ def test_resolve_local_tries_to_load_globally_used_params_yaml(tmp_dir, dvc):
"build": {
"cmd": "command --value ${bar}",
"params": [{"params.yaml": ["foo"]}],
"vars": ["params.yaml"],
},
},
}
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from dvc.parsing import DEFAULT_PARAMS_FILE
from dvc.parsing.context import Context, CtxDict, CtxList, Value
from dvc.tree.local import LocalTree
from dvc.utils.serialize import dump_yaml
Expand Down Expand Up @@ -379,3 +380,9 @@ def test_resolve_resolves_dict_keys():
assert context.resolve({"${dct.foo}": {"persist": "${dct.persist}"}}) == {
"foobar": {"persist": True}
}


def test_merge_from_raises_if_file_not_exist(tmp_dir, dvc):
    """``Context.merge_from`` must raise when the default params file is absent."""
    ctx = Context(foo="bar")
    missing_file = DEFAULT_PARAMS_FILE
    # No params file has been written in this test, so the merge must fail.
    with pytest.raises(FileNotFoundError):
        ctx.merge_from(dvc.tree, missing_file)

0 comments on commit 40c3bbd

Please sign in to comment.