Skip to content

Commit

Permalink
add tests for metrics wrong type and some other edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Feb 19, 2019
1 parent e9b5725 commit b79085e
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 57 deletions.
3 changes: 1 addition & 2 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import unicode_literals

from dvc.utils.compat import str, builtin_str, open
from dvc.utils.compat import str

import collections
import os
import dvc.prompt as prompt
import dvc.logger as logger
Expand Down
7 changes: 5 additions & 2 deletions dvc/repo/metrics/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


def modify(repo, path, typ=None, xpath=None, delete=False):
supported_types = ["raw", "json", "csv", "tsv", "hcsv", "htsv"]
outs = repo.find_outs_by_path(path)
assert len(outs) == 1
out = outs[0]
Expand All @@ -15,8 +16,10 @@ def modify(repo, path, typ=None, xpath=None, delete=False):
if typ is not None:
typ = typ.lower().strip()
if typ not in ["raw", "json", "csv", "tsv", "hcsv", "htsv"]:
msg = "metric type '{}' is not supported"
raise DvcException(msg.format(typ))
msg = "metric type '{typ}' is not supported, must be one of [{types}]"
raise DvcException(
msg.format(typ=typ, types=", ".join(supported_types))
)
if not isinstance(out.metric, dict):
out.metric = {}
out.metric[out.PARAM_METRIC_TYPE] = typ
Expand Down
78 changes: 52 additions & 26 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,69 +15,95 @@ def _read_metric_json(fd, json_path):
return [x.value for x in parser.find(json.load(fd))]


def _get_values(row):
    """Return the cell values of one parsed row.

    Rows produced by ``csv.DictReader`` are mappings and are flattened to a
    list of their values; rows produced by ``csv.reader`` are already lists
    and are returned unchanged.
    """
    if isinstance(row, dict):
        return list(row.values())
    return row


def _do_read_metric_xsv(reader, row, col):
    """Select cells from an already materialised table.

    Args:
        reader: list of parsed rows (lists, or dicts for header-aware files).
        row: index of the row to select, or None for every row.
        col: column index (or header name for dict rows), or None for
            every column.

    Returns:
        - ``[cell]`` when both ``row`` and ``col`` are given;
        - the requested column across all rows when only ``col`` is given;
        - the values of the requested row when only ``row`` is given;
        - a list with the values of every row when neither is given.

    Raises:
        IndexError / KeyError: when ``row`` or ``col`` does not exist.
    """
    if col is not None and row is not None:
        return [reader[row][col]]
    if col is not None:
        return [r[col] for r in reader]
    if row is not None:
        return _get_values(reader[row])
    # No selector at all: return the whole table instead of None so that an
    # empty path such as "," still yields the file's content.
    return [_get_values(r) for r in reader]


def _read_metric_hxsv(fd, hxsv_path, delimiter):
    """Read cells from a delimiter-separated file that has a header line.

    Args:
        fd: open text file object positioned at the start of the data.
        hxsv_path: selector of the form ``"<row>,<column name>"``. Either
            part may be empty (``"0"``, ``"0,"``, ``",loss"``, ``","``), in
            which case that dimension is not restricted. Parts after the
            second comma are ignored.
        delimiter: field separator, e.g. ``","`` or ``"\t"``.

    Returns:
        Whatever ``_do_read_metric_xsv`` selects from the parsed rows.

    Raises:
        ValueError: when the row part is not empty and not an integer.
    """
    indices = hxsv_path.split(",")
    # An empty part means "no restriction"; unpacking into exactly two names
    # would fail for a selector without a comma.
    row = int(indices[0]) if indices[0] else None
    col = indices[1] if len(indices) > 1 and indices[1] else None
    reader = list(csv.DictReader(fd, delimiter=builtin_str(delimiter)))
    return _do_read_metric_xsv(reader, row, col)


def _read_metric_xsv(fd, xsv_path, delimiter):
    """Read cells from a header-less delimiter-separated file.

    Args:
        fd: open text file object positioned at the start of the data.
        xsv_path: selector of the form ``"<row>,<column>"`` with integer
            indices. Either part may be empty (``"0"``, ``"0,"``, ``",1"``,
            ``","``), in which case that dimension is not restricted. Parts
            after the second comma are ignored.
        delimiter: field separator, e.g. ``","`` or ``"\t"``.

    Returns:
        Whatever ``_do_read_metric_xsv`` selects from the parsed rows.

    Raises:
        ValueError: when a non-empty part is not an integer.
    """
    indices = xsv_path.split(",")
    # An empty part means "no restriction"; calling int() on it
    # unconditionally would raise for selectors such as "0" or ",1".
    row = int(indices[0]) if indices[0] else None
    col = int(indices[1]) if len(indices) > 1 and indices[1] else None
    reader = list(csv.reader(fd, delimiter=builtin_str(delimiter)))
    return _do_read_metric_xsv(reader, row, col)


def _read_typed_metric(typ, xpath, fd):
    """Parse an open metric file according to its declared type.

    Args:
        typ: normalised metric type; "json", "csv", "tsv", "hcsv" and "htsv"
            are parsed, anything else is treated as raw text.
        xpath: selector passed through to the type-specific parser.
        fd: open text file object.

    Returns:
        The parser's result, or the stripped file content for raw text.
    """
    if typ == "json":
        return _read_metric_json(fd, xpath)
    if typ == "csv":
        return _read_metric_xsv(fd, xpath, ",")
    if typ == "tsv":
        return _read_metric_xsv(fd, xpath, "\t")
    if typ == "hcsv":
        return _read_metric_hxsv(fd, xpath, ",")
    if typ == "htsv":
        return _read_metric_hxsv(fd, xpath, "\t")
    return fd.read().strip()


def _read_metric(path, typ=None, xpath=None, branch=None):
    """Read a metric value from ``path``.

    Args:
        path: location of the metric file.
        typ: metric type; case and surrounding whitespace are ignored.
        xpath: selector inside the file; surrounding whitespace is ignored.
            Without a selector the whole file is returned as stripped text,
            whatever the type is.
        branch: name used only in the warning message.

    Returns:
        The metric value, or None when the file does not exist or could not
        be read/parsed (a warning is logged in the latter case).
    """
    ret = None
    if not os.path.exists(path):
        return ret

    typ = typ.lower().strip() if typ else typ
    xpath = xpath.strip() if xpath else xpath
    try:
        with open(path, "r") as fd:
            if not xpath:
                ret = fd.read().strip()
            else:
                ret = _read_typed_metric(typ, xpath, fd)
    # Json path library has to be replaced or wrapped in
    # order to fix this too broad except clause.
    except Exception:
        logger.warning(
            "unable to read metric in '{}' in branch '{}'".format(
                path, branch
            ),
            parse_exception=True,
        )

    return ret


def _collect_metrics(self, path, recursive, typ, xpath):
def _collect_metrics(self, path, recursive, typ, xpath, branch):
outs = [out for stage in self.stages() for out in stage.outs]

if path:
try:
outs = self.find_outs_by_path(path, outs=outs, recursive=recursive)
except OutputNotFoundError:
logger.debug(
"stage file not for found for '{}' in branch '{}'".format(
path, branch
)
)
return []

res = []
Expand All @@ -97,7 +123,7 @@ def _collect_metrics(self, path, recursive, typ, xpath):
return res


def _read_metrics(self, metrics):
def _read_metrics(self, metrics, branch):
res = {}
for out, typ, xpath in metrics:
assert out.scheme == "local"
Expand All @@ -106,7 +132,7 @@ def _read_metrics(self, metrics):
else:
path = out.path

metric = _read_metric(path, typ=typ, xpath=xpath)
metric = _read_metric(path, typ=typ, xpath=xpath, branch=branch)
if not metric:
continue

Expand All @@ -128,8 +154,8 @@ def show(
for branch in self.scm.brancher(
all_branches=all_branches, all_tags=all_tags
):
entries = _collect_metrics(self, path, recursive, typ, xpath)
metrics = _read_metrics(self, entries)
entries = _collect_metrics(self, path, recursive, typ, xpath, branch)
metrics = _read_metrics(self, entries, branch)
if metrics:
res[branch] = metrics

Expand Down
Loading

0 comments on commit b79085e

Please sign in to comment.