Skip to content

Commit

Permalink
disallow overriding item and key in the foreach loop (iterative#5104
Browse files Browse the repository at this point in the history
)

* Allow reserving keys in the context to prevent overriding

* Use nullcontext from funcy (3.6 Python)

* Display path in error message
  • Loading branch information
skshetry authored Dec 15, 2020
1 parent 6cad3c9 commit 11c94b1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 24 deletions.
13 changes: 2 additions & 11 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,6 @@ def resolve_one(self, key):
return self._each_iter(key)

def _each_iter(self, key):
name = self.name
err_message = f"Could not find '{key}' in foreach group '{self.name}'"
with reraise(KeyError, EntryNotFound(err_message)):
value = self.normalized_iterable[key]
Expand All @@ -420,7 +419,7 @@ def _each_iter(self, key):
if key_str in inserted:
temp_dict[key_str] = key

with self.context.set_temporarily(temp_dict):
with self.context.set_temporarily(temp_dict, reserve=True):
# optimization: item and key can be removed on __exit__() as they
# are top-level values, and are not merged recursively.
# This helps us avoid cloning context, which is slower
Expand All @@ -436,14 +435,6 @@ def _each_iter(self, key):
# generated stages. We do it once when accessing do_definition.
return entry.resolve_stage(skip_checks=True)
except ContextError as exc:
# pylint: disable=no-member
if isinstance(exc, MergeError) and exc.key in inserted:
raise ResolveError(
f"attempted to redefine '{exc.key}' "
f"in stage '{generated}' generated through 'foreach'"
)
format_and_raise(
exc,
f"stage '{generated}' (gen. from '{name}')",
self.relpath,
exc, f"stage '{generated}'", self.relpath,
)
60 changes: 53 additions & 7 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass, field, replace
from typing import Any, List, Optional, Union

from funcy import identity, lfilter
from funcy import identity, lfilter, nullcontext, select

from dvc.exceptions import DvcException
from dvc.parsing.interpolate import (
Expand All @@ -21,6 +21,7 @@
)
from dvc.path_info import PathInfo
from dvc.utils import relpath
from dvc.utils.humanize import join
from dvc.utils.serialize import LOADERS

logger = logging.getLogger(__name__)
Expand All @@ -31,6 +32,18 @@ class ContextError(DvcException):
pass


class ReservedKeyError(ContextError):
def __init__(self, keys, path=None):
self.keys = keys
self.path = path

n = "key" + ("s" if len(keys) > 1 else "")
msg = f"attempted to modify reserved {n} {join(keys)}"
if path:
msg += f" in '{path}'"
super().__init__(msg)


class MergeError(ContextError):
def __init__(self, key, new, into):
self.key = key
Expand Down Expand Up @@ -259,9 +272,8 @@ def __setitem__(self, key, value):
return
return super().__setitem__(key, value)

def merge_update(self, *args, overwrite=False):
for d in args:
_merge(self, d, overwrite=overwrite)
def merge_update(self, other, overwrite=False):
_merge(self, other, overwrite=overwrite)

@property
def value(self):
Expand All @@ -285,6 +297,7 @@ def __init__(self, *args, **kwargs):
self._track = False
self._tracked_data = defaultdict(dict)
self.imports = {}
self._reserved_keys = {}

@contextmanager
def track(self):
Expand Down Expand Up @@ -357,6 +370,12 @@ def load_from(cls, tree, path: PathInfo, select_keys=None) -> "Context":
ctx.imports[os.path.abspath(path)] = select_keys or None
return ctx

def merge_update(self, other: "Context", overwrite=False):
matches = select(lambda key: key in other, self._reserved_keys.keys())
if matches:
raise ReservedKeyError(matches)
return super().merge_update(other, overwrite=overwrite)

def merge_from(
self, tree, item: str, wdir: PathInfo, overwrite=False,
):
Expand All @@ -371,7 +390,11 @@ def merge_from(
self.check_loaded(abspath, item, select_keys)

ctx = Context.load_from(tree, path_info, select_keys)
self.merge_update(ctx, overwrite=overwrite)

try:
self.merge_update(ctx, overwrite=overwrite)
except ReservedKeyError as exc:
raise ReservedKeyError(exc.keys, item) from exc

cp = ctx.imports[abspath]
if abspath not in self.imports:
Expand Down Expand Up @@ -424,6 +447,7 @@ def __deepcopy__(self, _):
new = Context(super().__deepcopy__(_))
new.meta = deepcopy(self.meta)
new.imports = deepcopy(self.imports)
new._reserved_keys = deepcopy(self._reserved_keys)
return new

@classmethod
Expand All @@ -432,14 +456,36 @@ def clone(cls, ctx: "Context") -> "Context":
return deepcopy(ctx)

@contextmanager
def set_temporarily(self, to_set):
def reserved(self, *keys: str):
"""Allow reserving some keys so that they cannot be overwritten.
Ideally, we should delegate this to a separate container
and support proper namespacing so that we could support `env` features.
But for now, just `item` and `key`, this should do.
"""
# using dict to make the error messages ordered
new = dict.fromkeys(
[key for key in keys if key not in self._reserved_keys]
)
self._reserved_keys.update(new)
try:
yield
finally:
for key in new.keys():
self._reserved_keys.pop(key)

@contextmanager
def set_temporarily(self, to_set, reserve=False):
cm = self.reserved(*to_set) if reserve else nullcontext()

non_existing = frozenset(to_set.keys() - self.keys())
prev = {key: self[key] for key in to_set if key not in non_existing}
to_set = CtxDict(to_set)
self.update(to_set)

try:
yield
with cm:
yield
finally:
self.update(prev)
for key in non_existing:
Expand Down
25 changes: 19 additions & 6 deletions tests/func/parsing/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import re

import pytest
from funcy import first

from dvc.parsing import ResolveError
from dvc.parsing.context import Context
from dvc.parsing.interpolate import embrace
from dvc.utils.humanize import join
from dvc.utils.serialize import dump_yaml

from . import make_entry_definition, make_foreach_def
Expand Down Expand Up @@ -268,7 +268,13 @@ def test_foreach_do_definition_item_does_not_exist(tmp_dir, dvc, key, loc):


@pytest.mark.parametrize(
"redefine", [{"item": 5}, {"key": 5}, {"item": 5, "key": 10}],
"redefine",
[
{"item": 5},
{"key": 5},
{"item": 5, "key": 10},
{"item": {"epochs": 10}},
],
)
@pytest.mark.parametrize("from_file", [True, False])
def test_item_key_in_generated_stage_vars(tmp_dir, dvc, redefine, from_file):
Expand All @@ -288,10 +294,17 @@ def test_item_key_in_generated_stage_vars(tmp_dir, dvc, redefine, from_file):

with pytest.raises(ResolveError) as exc_info:
definition.resolve_all()
assert str(exc_info.value) == (
f"attempted to redefine '{first(redefine)}' in stage 'build@model1'"
" generated through 'foreach'"
)

message = str(exc_info.value)
assert (
"failed to parse stage 'build@model1' in 'dvc.yaml': "
"attempted to modify reserved"
) in message

key_or_keys = "keys" if len(redefine) > 1 else "key"
assert f"{key_or_keys} {join(redefine)}" in message
if from_file:
assert "in 'test_params.yaml'" in message
assert context == {"foo": "bar"}


Expand Down

0 comments on commit 11c94b1

Please sign in to comment.