Skip to content

Commit

Permalink
Modify yaml/toml using contextmanager, removing boilerplate (iterativ…
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Aug 21, 2020
1 parent 2f2e712 commit 1d196b4
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 81 deletions.
53 changes: 24 additions & 29 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dvc.utils.serialize import (
dump_yaml,
load_yaml,
modify_yaml,
parse_yaml,
parse_yaml_for_update,
)
Expand Down Expand Up @@ -193,30 +194,27 @@ def _dump_lockfile(self, stage):
self._lockfile.dump(stage)

def _dump_pipeline_file(self, stage):
data = {}
if self.exists():
with open(self.path) as fd:
data = parse_yaml_for_update(fd.read(), self.path)
else:
logger.info("Creating '%s'", self.relpath)
open(self.path, "w+").close()

data["stages"] = data.get("stages", {})
stage_data = serialize.to_pipeline_file(stage)
existing_entry = stage.name in data["stages"]

action = "Modifying" if existing_entry else "Adding"
logger.info("%s stage '%s' in '%s'", action, stage.name, self.relpath)
with modify_yaml(self.path, tree=self.repo.tree) as data:
if not data:
logger.info("Creating '%s'", self.relpath)

if existing_entry:
orig_stage_data = data["stages"][stage.name]
if "meta" in orig_stage_data:
stage_data[stage.name]["meta"] = orig_stage_data["meta"]
apply_diff(stage_data[stage.name], orig_stage_data)
else:
data["stages"].update(stage_data)
data["stages"] = data.get("stages", {})
existing_entry = stage.name in data["stages"]
action = "Modifying" if existing_entry else "Adding"
logger.info(
"%s stage '%s' in '%s'", action, stage.name, self.relpath
)

if existing_entry:
orig_stage_data = data["stages"][stage.name]
if "meta" in orig_stage_data:
stage_data[stage.name]["meta"] = orig_stage_data["meta"]
apply_diff(stage_data[stage.name], orig_stage_data)
else:
data["stages"].update(stage_data)

dump_yaml(self.path, data)
self.repo.scm.track_file(self.relpath)

@property
Expand Down Expand Up @@ -281,21 +279,18 @@ def load(self):

def dump(self, stage, **kwargs):
stage_data = serialize.to_lockfile(stage)
if not self.exists():
modified = True
logger.info("Generating lock file '%s'", self.relpath)
data = stage_data
open(self.path, "w+").close()
else:
with self.repo.tree.open(self.path, "r") as fd:
data = parse_yaml_for_update(fd.read(), self.path)

with modify_yaml(self.path, tree=self.repo.tree) as data:
if not data:
logger.info("Generating lock file '%s'", self.relpath)

modified = data.get(stage.name, {}) != stage_data.get(
stage.name, {}
)
if modified:
logger.info("Updating lock file '%s'", self.relpath)
data.update(stage_data)
dump_yaml(self.path, data)

if modified:
self.repo.scm.track_file(self.relpath)

Expand Down
21 changes: 4 additions & 17 deletions dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Mapping
from concurrent.futures import (
ProcessPoolExecutor,
Expand Down Expand Up @@ -213,12 +212,7 @@ def _unpack_args(self, tree=None):

def _update_params(self, params: dict):
"""Update experiment params files with the specified values."""
from dvc.utils.serialize import (
dump_toml,
dump_yaml,
parse_toml_for_update,
parse_yaml_for_update,
)
from dvc.utils.serialize import MODIFIERS

logger.debug("Using experiment params '%s'", params)

Expand All @@ -231,19 +225,12 @@ def _update(dict_, other):
dict_[key] = value
return dict_

loaders = defaultdict(lambda: parse_yaml_for_update)
loaders.update({".toml": parse_toml_for_update})
dumpers = defaultdict(lambda: dump_yaml)
dumpers.update({".toml": dump_toml})

for params_fname in params:
path = PathInfo(self.exp_dvc.root_dir) / params_fname
with self.exp_dvc.tree.open(path, "r") as fobj:
text = fobj.read()
suffix = path.suffix.lower()
data = loaders[suffix](text, path)
_update(data, params[params_fname])
dumpers[suffix](path, data)
modify_data = MODIFIERS[suffix]
with modify_data(path, tree=self.exp_dvc.tree) as data:
_update(data, params[params_fname])

def _commit(self, exp_hash, check_exists=True, branch=True):
"""Commit stages as an experiment and return the commit SHA."""
Expand Down
58 changes: 27 additions & 31 deletions dvc/scm/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from dvc.utils import fix_env, is_binary, relpath
from dvc.utils.fs import path_isin
from dvc.utils.serialize import dump_yaml, load_yaml
from dvc.utils.serialize import modify_yaml

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -330,36 +330,32 @@ def install(self, use_pre_commit_tool=False):
return

config_path = os.path.join(self.root_dir, ".pre-commit-config.yaml")
config = load_yaml(config_path) if os.path.exists(config_path) else {}

entry = {
"repo": "https://github.com/iterative/dvc",
"rev": "master",
"hooks": [
{
"id": "dvc-pre-commit",
"language_version": "python3",
"stages": ["commit"],
},
{
"id": "dvc-pre-push",
"language_version": "python3",
"stages": ["push"],
},
{
"id": "dvc-post-checkout",
"language_version": "python3",
"stages": ["post-checkout"],
"always_run": True,
},
],
}

if entry in config["repos"]:
return

config["repos"].append(entry)
dump_yaml(config_path, config)
with modify_yaml(config_path) as config:
entry = {
"repo": "https://github.com/iterative/dvc",
"rev": "master",
"hooks": [
{
"id": "dvc-pre-commit",
"language_version": "python3",
"stages": ["commit"],
},
{
"id": "dvc-pre-push",
"language_version": "python3",
"stages": ["push"],
},
{
"id": "dvc-post-checkout",
"language_version": "python3",
"stages": ["post-checkout"],
"always_run": True,
},
],
}

if entry not in config["repos"]:
config["repos"].append(entry)

def cleanup_ignores(self):
for path in self.ignored_paths:
Expand Down
3 changes: 3 additions & 0 deletions dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@

LOADERS = defaultdict(lambda: load_yaml) # noqa: F405
LOADERS.update({".toml": load_toml}) # noqa: F405

MODIFIERS = defaultdict(lambda: modify_yaml) # noqa: F405
MODIFIERS.update({".toml": modify_toml}) # noqa: F405
10 changes: 10 additions & 0 deletions dvc/utils/serialize/_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Common utilities for serialize."""
import os
from contextlib import contextmanager

from dvc.exceptions import DvcException
from dvc.utils import relpath
Expand All @@ -22,3 +24,11 @@ def _dump_data(path, data, dumper, tree=None):
open_fn = tree.open if tree else open
with open_fn(path, "w+", encoding="utf-8") as fd:
dumper(data, fd)


@contextmanager
def _modify_data(path, parser, dumper, tree=None):
exists = tree.exists if tree else os.path.exists
data = _load_data(path, parser=parser, tree=tree) if exists(path) else {}
yield data
dumper(path, data, tree=tree)
10 changes: 9 additions & 1 deletion dvc/utils/serialize/_toml.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from contextlib import contextmanager

import toml
from funcy import reraise

from ._common import ParseError, _dump_data, _load_data
from ._common import ParseError, _dump_data, _load_data, _modify_data


class TOMLFileCorruptedError(ParseError):
Expand Down Expand Up @@ -35,3 +37,9 @@ def _dump(data, stream):

def dump_toml(path, data, tree=None):
return _dump_data(path, data, dumper=_dump, tree=tree)


@contextmanager
def modify_toml(path, tree=None):
with _modify_data(path, parse_toml_for_update, dump_toml, tree=tree) as d:
yield d
12 changes: 9 additions & 3 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import io
from collections import OrderedDict
from contextlib import contextmanager

from funcy import reraise
from ruamel.yaml import YAML
from ruamel.yaml.error import YAMLError

from ._common import ParseError, _dump_data, _load_data
from ._common import ParseError, _dump_data, _load_data, _modify_data


class YAMLFileCorruptedError(ParseError):
Expand Down Expand Up @@ -60,6 +61,11 @@ def loads_yaml(s, typ="safe"):

def dumps_yaml(d):
stream = io.StringIO()
yaml = _get_yaml()
yaml.dump(d, stream)
_dump(d, stream)
return stream.getvalue()


@contextmanager
def modify_yaml(path, tree=None):
with _modify_data(path, parse_yaml_for_update, dump_yaml, tree=tree) as d:
yield d

0 comments on commit 1d196b4

Please sign in to comment.