Skip to content

Commit

Permalink
Fix merge_update not converting data to Node object (iterative#4846)
Browse files Browse the repository at this point in the history
We wrap every data to one of the `Node` object. But, since we were
passing `self.data` instead of `self`, it was skipping the conversion
and instead was just trying to set data as-is, which would have made
our later assumptions incorrect, as other codebases assume it to be
a `Node`.

I have added asserts in a few places. Also, I noticed that, on a few
test assumptions, when `resolving`, we were not converting values
back to its original form (`CtxDict` -> `dict`, `CtxList` -> `list`
and `Value` -> _). So, I have added `node.value` property, as it is
a complement of the above fix.

This caused a ripple effect for the changes in the API as `foreach`
now resolved and changed the data to the original form,
which prevents us from tracking the data later in the "field".
So, `resolve`/`select` got a real `unwrap` option that should
reliably work for all kinds of `Node` objects.
  • Loading branch information
skshetry authored Nov 7, 2020
1 parent f4ee0e1 commit 9f1330a
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 61 deletions.
2 changes: 1 addition & 1 deletion dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def each_iter(value, key=DEFAULT_SENTINEL):
suffix = str(key if key is not DEFAULT_SENTINEL else value)
return self._resolve_stage(c, f"{name}-{suffix}", in_data)

iterable = context.resolve(foreach_data)
iterable = context.resolve(foreach_data, unwrap=False)

assert isinstance(iterable, (Sequence, Mapping)) and not isinstance(
iterable, str
Expand Down
75 changes: 60 additions & 15 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, MutableMapping, MutableSequence, Sequence
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional, Union

from funcy import identity
from funcy import identity, rpartial

from dvc.parsing.interpolate import (
get_expression,
Expand All @@ -30,6 +31,7 @@ def _merge(into, update, overwrite):
f"Cannot overwrite as key {key} already exists in {into}"
)
into[key] = val
assert isinstance(into[key], Node)


@dataclass
Expand All @@ -55,24 +57,41 @@ def _default_meta():
return Meta(source=None)


class Node:
    """Common interface for wrapped context data.

    Every piece of data stored in a `Context` is wrapped in a `Node`
    subclass; `value` exposes the underlying plain-Python data.
    """

    def get_sources(self):
        """Report where this node's data came from (overridden by subclasses)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def value(self):
        """Return the wrapped data in its original (unwrapped) form."""


@dataclass
class Value:
value: Any
class Value(Node):
_value: Any
meta: Meta = field(
compare=False, default_factory=_default_meta, repr=False
)

def __repr__(self):
return repr(self.value)
return repr(self._value)

def __str__(self) -> str:
return str(self.value)
return str(self._value)

def get_sources(self):
return {self.meta.source: self.meta.path()}

@property
def value(self):
return self._value


PRIMITIVES = (int, float, str, bytes, bool)

class Container:

class Container(Node, ABC):
meta: Meta
data: Union[list, dict]
_key_transform = staticmethod(identity)
Expand All @@ -82,9 +101,9 @@ def __init__(self, meta=None) -> None:

def _convert(self, key, value):
meta = Meta.update_path(self.meta, key)
if value is None or isinstance(value, (int, float, str, bytes, bool)):
if value is None or isinstance(value, PRIMITIVES):
return Value(value, meta=meta)
elif isinstance(value, (CtxList, CtxDict, Value)):
elif isinstance(value, Node):
return value
elif isinstance(value, (list, dict)):
container = CtxDict if isinstance(value, dict) else CtxList
Expand Down Expand Up @@ -150,6 +169,10 @@ def insert(self, index: int, value):
def get_sources(self):
return {self.meta.source: self.meta.path()}

@property
def value(self):
return [node.value for node in self]


class CtxDict(Container, MutableMapping):
def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs):
Expand All @@ -169,7 +192,11 @@ def __setitem__(self, key, value):

def merge_update(self, *args, overwrite=False):
for d in args:
_merge(self.data, d, overwrite=overwrite)
_merge(self, d, overwrite=overwrite)

@property
def value(self):
return {key: node.value for key, node in self.items()}


class Context(CtxDict):
Expand Down Expand Up @@ -205,10 +232,21 @@ def tracked(self):
def select(
self, key: str, unwrap=False
): # pylint: disable=arguments-differ
# NOTE: `unwrap` default value is different from `resolve_str`
"""Select the item using key, similar to `__getitem__`
but can track the usage of the data on interpolation
as well and can get from nested data structure by using
"." separated key (eg: "key1.key2.key3")
Args:
key: key to select value from
unwrap: Convert CtxList/CtxDict/Value items to it's original data
Defaults to False. Note that the default is different from
`resolve`.
"""
node = super().select(key)
self._track_data(node)
return node.value if isinstance(node, Value) and unwrap else node
assert isinstance(node, Node)
return node.value if unwrap else node

@classmethod
def load_from(cls, tree, file: str) -> "Context":
Expand Down Expand Up @@ -243,20 +281,27 @@ def set(self, key, value):
self._check_interpolation_collection(key, value)
self[key] = value

def resolve(self, src):
def resolve(self, src, unwrap=True):
"""Recursively resolves interpolation and returns resolved data.
Args:
src: Data (str/list/dict etc.) to resolve
unwrap: Unwrap CtxDict/CtxList/Value to it's original data if
inside `src`. Defaults to True.
>>> c = Context({"three": 3})
>>> c.resolve({"lst": [1, 2, "${three}"]})
{'lst': [1, 2, 3]}
"""
Seq = (list, tuple, set)

resolve = rpartial(self.resolve, unwrap)
if isinstance(src, Mapping):
return {key: self.resolve(value) for key, value in src.items()}
return {key: resolve(value) for key, value in src.items()}
elif isinstance(src, Seq):
return type(src)(map(self.resolve, src))
return type(src)(map(resolve, src))
elif isinstance(src, str):
return self.resolve_str(src)
return self.resolve_str(src, unwrap=unwrap)
return src

def resolve_str(self, src: str, unwrap=True):
Expand Down
113 changes: 71 additions & 42 deletions tests/func/test_stage_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dvc.dependency import _merge_params
from dvc.parsing import DEFAULT_PARAMS_FILE, DataResolver
from dvc.parsing.context import Node
from dvc.path_info import PathInfo
from dvc.utils.serialize import dump_json, dump_yaml

Expand Down Expand Up @@ -44,9 +45,19 @@
}


def recurse_not_a_node(d):
    """Assert that *d* contains no `Node` instances, recursing into
    lists and dict values.

    Always returns True so calls can be chained inside `assert`.
    """
    assert not isinstance(d, Node)
    if isinstance(d, dict):
        children = d.values()
    elif isinstance(d, list):
        children = d
    else:
        children = ()
    for child in children:
        assert recurse_not_a_node(child)
    return True


def assert_stage_equal(d1, d2):
"""Keeps the params section in order, and then checks for equality."""
for d in [d1, d2]:
assert recurse_not_a_node(d)
for _, stage_d in d.get("stages", {}).items():
params = _merge_params(stage_d.get("params", []))
for k in params:
Expand Down Expand Up @@ -248,12 +259,15 @@ def test_simple_foreach_loop(tmp_dir, dvc):
}

resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
f"build-{item}": {"cmd": f"python script.py {item}"}
for item in iterable
}
}
assert_stage_equal(
resolver.resolve(),
{
"stages": {
f"build-{item}": {"cmd": f"python script.py {item}"}
for item in iterable
}
},
)


def test_foreach_loop_dict(tmp_dir, dvc):
Expand All @@ -268,12 +282,15 @@ def test_foreach_loop_dict(tmp_dir, dvc):
}

resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
f"build-{key}": {"cmd": f"python script.py {item['thresh']}"}
for key, item in iterable["models"].items()
}
}
assert_stage_equal(
resolver.resolve(),
{
"stages": {
f"build-{key}": {"cmd": f"python script.py {item['thresh']}"}
for key, item in iterable["models"].items()
}
},
)


def test_foreach_loop_templatized(tmp_dir, dvc):
Expand Down Expand Up @@ -319,14 +336,17 @@ def test_set(tmp_dir, dvc, value):
}
}
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
"build": {
"cmd": f"python script.py --thresh {value}",
"always_changed": value,
assert_stage_equal(
resolver.resolve(),
{
"stages": {
"build": {
"cmd": f"python script.py --thresh {value}",
"always_changed": value,
}
}
}
}
},
)


@pytest.mark.parametrize(
Expand All @@ -343,11 +363,14 @@ def test_coll(tmp_dir, dvc, coll):
}
}
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
"build": {"cmd": "python script.py --thresh 10", "outs": coll}
}
}
assert_stage_equal(
resolver.resolve(),
{
"stages": {
"build": {"cmd": "python script.py --thresh 10", "outs": coll}
}
},
)


def test_set_with_foreach(tmp_dir, dvc):
Expand All @@ -362,12 +385,15 @@ def test_set_with_foreach(tmp_dir, dvc):
}
}
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
f"build-{item}": {"cmd": f"command --value {item}"}
for item in items
}
}
assert_stage_equal(
resolver.resolve(),
{
"stages": {
f"build-{item}": {"cmd": f"command --value {item}"}
for item in items
}
},
)


def test_set_with_foreach_and_on_stage_definition(tmp_dir, dvc):
Expand All @@ -388,15 +414,18 @@ def test_set_with_foreach_and_on_stage_definition(tmp_dir, dvc):
},
}
resolver = DataResolver(dvc, PathInfo(str(tmp_dir)), d)
assert resolver.resolve() == {
"stages": {
"build-us": {
"cmd": "command --value 10",
"params": [{"params.json": ["models.us.thresh"]}],
},
"build-gb": {
"cmd": "command --value 15",
"params": [{"params.json": ["models.gb.thresh"]}],
},
}
}
assert_stage_equal(
resolver.resolve(),
{
"stages": {
"build-us": {
"cmd": "command --value 10",
"params": [{"params.json": ["models.us.thresh"]}],
},
"build-gb": {
"cmd": "command --value 15",
"params": [{"params.json": ["models.gb.thresh"]}],
},
}
},
)
Loading

0 comments on commit 9f1330a

Please sign in to comment.