Skip to content

Commit

Permalink
dvc: retain formatting and comments in .dvc files (iterative#1885)
Browse files Browse the repository at this point in the history
* test: remove never used copy of load_stage_file()

* dvc: save comments and custom formatting in stage files

* Add a couple of TODOs

* test: test repro doesn't clear comments and formatting in .dvc

As part of this I started writing pytest-like tests and added commonly
used fixtures.

* test: simplify and prettify dvc formatting retained test

Also test that line comments after md5 fields are retained.

* test: fix tests

* dvc: fix stage and metrics opens not closed

* dvc: always use utf-8 in stage files

It used to encode in system default encoding. Not sure if PyYAML was
handling this silently or we were just lucky all the way.

* test: unit test apply_diff()

* test: remove TODOs and ASK

* dvc: polish dvc-formatting PR

* dvc: move stage load/dump utils to a new stage utils file

* dvc: clean up apply_diff()

* test: test stage md5 ignores comments

* test: remove pytests starting marker
  • Loading branch information
Suor authored and efiop committed Apr 19, 2019
1 parent b057c81 commit 0c86739
Show file tree
Hide file tree
Showing 22 changed files with 356 additions and 128 deletions.
4 changes: 2 additions & 2 deletions dvc/command/tag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import yaml
import logging

from dvc.utils import to_yaml_string
from dvc.exceptions import DvcException
from dvc.command.base import CmdBase, fix_subparsers, append_doc_link

Expand Down Expand Up @@ -50,7 +50,7 @@ def run(self):
recursive=self.args.recursive,
)
if tags:
logger.info(yaml.dump(tags, default_flow_style=False))
logger.info(to_yaml_string(tags))
except DvcException:
logger.exception("failed list tags")
return 1
Expand Down
5 changes: 3 additions & 2 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ def __init__(self):


class StageFileCorruptedError(DvcException):
def __init__(self, path):
def __init__(self, path, cause=None):
path = os.path.relpath(path)
super(StageFileCorruptedError, self).__init__(
"unable to read stage file: {} "
"YAML file structure is corrupted".format(path)
"YAML file structure is corrupted".format(path),
cause=cause,
)


Expand Down
12 changes: 8 additions & 4 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,14 @@ def _read_metrics(self, metrics, branch):
branch=branch,
)
else:
fd = self.tree.open(out.path)
metric = _read_metric(
fd, typ=typ, xpath=xpath, rel_path=out.rel_path, branch=branch
)
with self.tree.open(out.path) as fd:
metric = _read_metric(
fd,
typ=typ,
xpath=xpath,
rel_path=out.rel_path,
branch=branch,
)

if not metric:
continue
Expand Down
22 changes: 15 additions & 7 deletions dvc/stage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import unicode_literals

from dvc.utils.compat import str, open
from dvc.utils.compat import str

import copy
import re
import os
import yaml
import subprocess
import logging

Expand All @@ -15,7 +15,9 @@
import dvc.dependency as dependency
import dvc.output as output
from dvc.exceptions import DvcException
from dvc.utils import dict_md5, fix_env, load_stage_file_fobj
from dvc.utils import dict_md5, fix_env
from dvc.utils.collections import apply_diff
from dvc.utils.stage import load_stage_fd, dump_stage_file


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -146,6 +148,7 @@ def __init__(
md5=None,
locked=False,
tag=None,
state=None,
):
if deps is None:
deps = []
Expand All @@ -161,6 +164,7 @@ def __init__(
self.md5 = md5
self.locked = locked
self.tag = tag
self._state = state or {}

def __repr__(self):
return "Stage: '{path}'".format(
Expand Down Expand Up @@ -565,7 +569,11 @@ def load(repo, fname):
Stage._check_dvc_filename(fname)
Stage._check_isfile(repo, fname)

d = load_stage_file_fobj(repo.tree.open(fname), fname)
with repo.tree.open(fname) as fd:
d = load_stage_fd(fd, fname)
# Making a deepcopy since the original structure
# looses keys in deps and outs load
state = copy.deepcopy(d)

Stage.validate(d, fname=os.path.relpath(fname))
path = os.path.abspath(fname)
Expand All @@ -582,6 +590,7 @@ def load(repo, fname):
md5=d.get(Stage.PARAM_MD5),
locked=d.get(Stage.PARAM_LOCKED, False),
tag=tag,
state=state,
)

stage.deps = dependency.loadd_from(stage, d.get(Stage.PARAM_DEPS, []))
Expand Down Expand Up @@ -618,9 +627,8 @@ def dump(self):
)
)
d = self.dumpd()

with open(fname, "w") as fd:
yaml.safe_dump(d, fd, default_flow_style=False)
apply_diff(d, self._state)
dump_stage_file(fname, self._state)

self.repo.scm.track_file(os.path.relpath(fname))

Expand Down
23 changes: 10 additions & 13 deletions dvc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from __future__ import unicode_literals

import yaml
from dvc.utils.compat import str, builtin_str, open, cast_bytes_py2
from dvc.utils.compat import str, builtin_str, open, cast_bytes_py2, StringIO

import os
import sys
Expand All @@ -19,7 +18,7 @@
import re
import logging

from yaml.scanner import ScannerError
from ruamel.yaml import YAML


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -243,18 +242,16 @@ def current_timestamp():
return int(nanotime.timestamp(time.time()))


def load_stage_file(path):
with open(path, "r") as fobj:
return load_stage_file_fobj(fobj, path)
def from_yaml_string(s):
return YAML().load(StringIO(s))


def load_stage_file_fobj(fobj, path):
from dvc.exceptions import StageFileCorruptedError

try:
return yaml.safe_load(fobj) or {}
except ScannerError:
raise StageFileCorruptedError(path)
def to_yaml_string(data):
stream = StringIO()
yaml = YAML()
yaml.default_flow_style = False
yaml.dump(data, stream)
return stream.getvalue()


def dvc_walk(
Expand Down
43 changes: 42 additions & 1 deletion dvc/utils/collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,47 @@
from __future__ import unicode_literals
from __future__ import absolute_import, unicode_literals
from dvc.utils.compat import Mapping


# just simple check for Nones and emtpy strings
def compact(args):
return list(filter(bool, args))


def apply_diff(src, dest):
"""Recursively apply changes from stc to dest"""
Seq = (list, tuple)
Container = (Mapping, list, tuple)

def is_same_type(a, b):
return any(
isinstance(a, t) and isinstance(b, t)
for t in [str, Mapping, Seq, bool]
)

if isinstance(src, Mapping) and isinstance(dest, Mapping):
for key, value in src.items():
if isinstance(value, Container) and is_same_type(
value, dest.get(key)
):
apply_diff(value, dest[key])
elif key not in dest or value != dest[key]:
dest[key] = value
for key in set(dest) - set(src):
del dest[key]
elif isinstance(src, Seq) and isinstance(dest, Seq):
if len(src) != len(dest):
dest[:] = src
else:
for i, value in enumerate(src):
if isinstance(value, Container) and is_same_type(
value, dest[i]
):
apply_diff(value, dest[i])
elif value != dest[i]:
dest[i] = value
else:
raise AssertionError(
"Can't apply diff from {} to {}".format(
src.__class__.__name__, dest.__class__.__name__
)
)
25 changes: 23 additions & 2 deletions dvc/utils/compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Handle import compatibility between Python 2 and Python 3"""
from __future__ import absolute_import

import sys
import os
Expand Down Expand Up @@ -92,12 +93,12 @@ def _makedirs(name, mode=0o777, exist_ok=False):

if is_py2:
from urlparse import urlparse, urljoin # noqa: F401
from StringIO import StringIO # noqa: F401
from io import BytesIO # noqa: F401
from BaseHTTPServer import HTTPServer # noqa: F401
from SimpleHTTPServer import SimpleHTTPRequestHandler # noqa: F401
import ConfigParser # noqa: F401
from io import open # noqa: F401
from pathlib2 import Path # noqa: F401
from collections import Mapping # noqa: F401

builtin_str = str # noqa: F821
bytes = str # noqa: F821
Expand All @@ -109,7 +110,26 @@ def _makedirs(name, mode=0o777, exist_ok=False):
cast_bytes_py2 = cast_bytes
makedirs = _makedirs

import StringIO
import io

class StringIO(StringIO.StringIO):
def __enter__(self):
return self

def __exit__(self, *args):
self.close()

class BytesIO(io.BytesIO):
def __enter__(self):
return self

def __exit__(self, *args):
self.close()


elif is_py3:
from pathlib import Path # noqa: F401
from os import makedirs # noqa: F401
from urllib.parse import urlparse, urljoin # noqa: F401
from io import StringIO, BytesIO # noqa: F401
Expand All @@ -118,6 +138,7 @@ def _makedirs(name, mode=0o777, exist_ok=False):
SimpleHTTPRequestHandler, # noqa: F401
) # noqa: F401
import configparser as ConfigParser # noqa: F401
from collections.abc import Mapping # noqa: F401

builtin_str = str # noqa: F821
str = str # noqa: F821
Expand Down
26 changes: 26 additions & 0 deletions dvc/utils/stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dvc.utils.compat import open

from ruamel.yaml.error import YAMLError
from ruamel.yaml import YAML

from dvc.exceptions import StageFileCorruptedError


def load_stage_file(path):
with open(path, "r", encoding="utf-8") as fd:
return load_stage_fd(fd, path)


def load_stage_fd(fd, path):
try:
yaml = YAML()
return yaml.load(fd) or {}
except YAMLError as exc:
raise StageFileCorruptedError(path, cause=exc)


def dump_stage_file(path, data):
with open(path, "w", encoding="utf-8") as fd:
yaml = YAML()
yaml.default_flow_style = False
yaml.dump(data, fd)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ PyInstaller==3.3.1
colorama>=0.3.9
configobj>=5.0.6
networkx>=2.1
pyyaml>=3.12
gitpython>=2.1.8
setuptools>=34.0.0
nanotime>=0.5.2
Expand All @@ -29,3 +28,5 @@ treelib>=1.5.5
inflect>=2.1.0
humanize>=0.5.1
dulwich>=0.19.11
ruamel.yaml>=0.15.91
pathlib2==2.3.3; python_version == "2.7"
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def run(self):
"colorama>=0.3.9",
"configobj>=5.0.6",
"networkx>=2.1",
"pyyaml>=3.12",
"gitpython>=2.1.8",
"setuptools>=34.0.0",
"nanotime>=0.5.2",
Expand All @@ -59,6 +58,7 @@ def run(self):
"inflect>=2.1.0",
"humanize>=0.5.1",
"dulwich>=0.19.11",
"ruamel.yaml==0.15.91",
]

# Extra dependencies for remote integrations
Expand All @@ -85,7 +85,7 @@ def run(self):
"azure": azure,
"ssh": ssh,
# NOTE: https://github.com/inveniosoftware/troubleshooting/issues/1
':python_version=="2.7"': ["futures"],
':python_version=="2.7"': ["futures", "pathlib2"],
},
keywords="data science, data version control, machine learning",
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*",
Expand Down
52 changes: 52 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
from git import Repo
from git.exc import GitCommandNotFound

from dvc.repo import Repo as DvcRepo
from .basic_env import TestDirFixture, logger


@pytest.fixture(autouse=True)
def debug():
logger.setLevel("DEBUG")


# Wrap class like fixture as pytest-like one to avoid code duplication
@pytest.fixture
def repo_dir():
old_fixture = TestDirFixture()
old_fixture.setUp()
try:
yield old_fixture
finally:
old_fixture.tearDown()


# NOTE: this duplicates code from GitFixture,
# would fix itself once class-based fixtures are removed
@pytest.fixture
def git(repo_dir):
# NOTE: handles EAGAIN error on BSD systems (osx in our case).
# Otherwise when running tests you might get this exception:
#
# GitCommandNotFound: Cmd('git') not found due to:
# OSError('[Errno 35] Resource temporarily unavailable')
retries = 5
while retries:
try:
git = Repo.init()
except GitCommandNotFound:
retries -= 1
continue
break

git.index.add([repo_dir.CODE])
git.index.commit("add code")
return git


@pytest.fixture
def dvc(repo_dir, git):
dvc = DvcRepo.init(repo_dir._root_dir)
dvc.scm.commit("init dvc")
return dvc
Loading

0 comments on commit 0c86739

Please sign in to comment.