Skip to content

Commit

Permalink
Merge pull request iterative#1733 from mroutis/fix-1716
Browse files Browse the repository at this point in the history
metrics: shows formatted multiline metrics
  • Loading branch information
efiop authored Mar 30, 2019
2 parents 374b75f + 298770e commit b3addbd
Show file tree
Hide file tree
Showing 4 changed files with 259 additions and 9 deletions.
18 changes: 16 additions & 2 deletions dvc/command/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,26 @@


def show_metrics(metrics, all_branches=False, all_tags=False):
"""
Args:
metrics (list): Where each element is either a `list`
if an xpath was specified, otherwise a `str`
"""
for branch, val in metrics.items():
if all_branches or all_tags:
logger.info("{}:".format(branch))
logger.info("{branch}:".format(branch=branch))

for fname, metric in val.items():
logger.info("\t{}: {}".format(fname, metric))
lines = metric if type(metric) is list else metric.splitlines()

if len(lines) > 1:
logger.info("\t{fname}:".format(fname=fname))

for line in lines:
logger.info("\t\t{content}".format(content=line))

else:
logger.info("\t{}: {}".format(fname, metric))


class CmdMetricsShow(CmdBase):
Expand Down
96 changes: 94 additions & 2 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import dvc.logger as logger
from dvc.exceptions import OutputNotFoundError, BadMetricError, NoMetricsError
from dvc.utils.compat import builtin_str, open
from dvc.utils.compat import builtin_str, open, StringIO, csv_reader


def _read_metric_json(fd, json_path):
Expand Down Expand Up @@ -66,13 +66,73 @@ def _read_typed_metric(typ, xpath, fd):
return ret


def _format_csv(content, delimiter):
"""Format delimited text to have same column width.
Args:
content (str): The content of a metric.
delimiter (str): Value separator
Returns:
str: Formatted content.
Example:
>>> content = (
"value_mse,deviation_mse,data_set\n"
"0.421601,0.173461,train\n"
"0.67528,0.289545,testing\n"
"0.671502,0.297848,validation\n"
)
>>> _format_csv(content, ",")
"value_mse deviation_mse data_set\n"
"0.421601 0.173461 train\n"
"0.67528 0.289545 testing\n"
"0.671502 0.297848 validation\n"
"""
reader = csv_reader(StringIO(content), delimiter=builtin_str(delimiter))
rows = [row for row in reader]
max_widths = [max(map(len, column)) for column in zip(*rows)]

lines = [
" ".join(
"{entry:{width}}".format(entry=entry, width=width + 2)
for entry, width in zip(row, max_widths)
)
for row in rows
]

return "\n".join(lines)


def _format_output(content, typ):
"""Tabularize the content according to its type.
Args:
content (str): The content of a metric.
typ (str): The type of metric -- (raw|json|tsv|htsv|csv|hcsv).
Returns:
str: Content in a raw or tabular format.
"""

if "csv" in str(typ):
return _format_csv(content, delimiter=",")

if "tsv" in str(typ):
return _format_csv(content, delimiter="\t")

return content


def _read_metric(fd, typ=None, xpath=None, rel_path=None, branch=None):
typ = typ.lower().strip() if typ else typ
try:
if xpath:
return _read_typed_metric(typ, xpath.strip(), fd)
else:
return fd.read().strip()
return _format_output(fd.read().strip(), typ)
# Json path library has to be replaced or wrapped in
# order to fix this too broad except clause.
except Exception:
Expand All @@ -86,6 +146,23 @@ def _read_metric(fd, typ=None, xpath=None, rel_path=None, branch=None):


def _collect_metrics(self, path, recursive, typ, xpath, branch):
"""Gather all the metric outputs.
Args:
path (str): Path to a metric file or a directory.
recursive (bool): If path is a directory, do a recursive search for
metrics on the given path.
typ (str): The type of metric to search for, could be one of the
following (raw|json|tsv|htsv|csv|hcsv).
xpath (str): Path to search for.
branch (str): Branch to look up for metrics.
Returns:
list(tuple): (output, typ, xpath)
- output:
- typ:
- xpath:
"""
outs = [out for stage in self.stages() for out in stage.outs]

if path:
Expand Down Expand Up @@ -126,6 +203,21 @@ def _read_metrics_filesystem(path, typ, xpath, rel_path, branch):


def _read_metrics(self, metrics, branch):
"""Read the content of each metric file and format it.
Args:
metrics (list): List of metric touples
branch (str): Branch to look up for metrics.
Returns:
A dict mapping keys with metrics path name and content.
For example:
{'metric.csv': ("value_mse deviation_mse data_set\n"
"0.421601 0.173461 train\n"
"0.67528 0.289545 testing\n"
"0.671502 0.297848 validation\n")}
"""
res = {}
for out, typ, xpath in metrics:
assert out.scheme == "local"
Expand Down
40 changes: 35 additions & 5 deletions dvc/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,48 @@ def encode(u, encoding=None):
return u.encode(encoding, "replace")


# NOTE: cast_bytes_py2 is taken from
# https://github.com/ipython/ipython_genutils
def csv_reader(unicode_csv_data, dialect=None, **kwargs):
"""csv.reader doesn't support Unicode input, so need to use some tricks
to work around this.
Source: https://docs.python.org/2/library/csv.html#csv-examples
"""
import csv

dialect = dialect or csv.excel

if is_py3:
# Python3 supports encoding by default, so just return the object
for row in csv.reader(unicode_csv_data, dialect=dialect, **kwargs):
yield [cell for cell in row]

else:
# csv.py doesn't do Unicode; encode temporarily as UTF-8:
reader = csv.reader(
utf_8_encoder(unicode_csv_data), dialect=dialect, **kwargs
)
for row in reader:
# decode UTF-8 back to Unicode, cell by cell:
yield [unicode(cell, "utf-8") for cell in row] # noqa: F821


def utf_8_encoder(unicode_csv_data):
"""Source: https://docs.python.org/2/library/csv.html#csv-examples"""
for line in unicode_csv_data:
yield line.encode("utf-8")


def cast_bytes(s, encoding=None):
"""Source: https://github.com/ipython/ipython_genutils"""
if not isinstance(s, bytes):
return encode(s, encoding)
return s


# NOTE _makedirs is taken from
# https://github.com/python/cpython/blob/
# 3ce3dea60646d8a5a1c952469a2eb65f937875b3/Lib/os.py#L196-L226
def _makedirs(name, mode=0o777, exist_ok=False):
"""Source: https://github.com/python/cpython/blob/
3ce3dea60646d8a5a1c952469a2eb65f937875b3/Lib/os.py#L196-L226
"""
head, tail = os.path.split(name)
if not tail:
head, tail = os.path.split(head)
Expand Down
114 changes: 114 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import os
import json

Expand Down Expand Up @@ -167,6 +168,119 @@ def test_xpath_all_with_header(self):
for b in ["foo", "bar", "baz"]:
self.assertSequenceEqual(ret[b]["metric_hcsv"], [[b]])

def test_formatted_output(self):
with open("metrics.csv", "w") as fobj:
# Labels are in Spanish to test unicode characters
fobj.write(
"valor_mse,desviación_mse,data_set\n"
"0.421601,0.173461,entrenamiento\n"
"0.67528,0.289545,pruebas\n"
"0.671502,0.297848,validación\n"
)

with open("metrics.tsv", "w") as fobj:
# Contains quoted newlines to test output correctness
fobj.write(
"value_mse\tdeviation_mse\tdata_set\n"
"0.421601\t0.173461\ttrain\n"
'0.67528\t0.289545\t"test\\ning"\n'
"0.671502\t0.297848\tvalidation\n"
)

with open("metrics.json", "w") as fobj:
fobj.write(
"{\n"
' "data_set": [\n'
' "train",\n'
' "testing",\n'
' "validation"\n'
" ],\n"
' "deviation_mse": [\n'
' "0.173461",\n'
' "0.289545",\n'
' "0.297848"\n'
" ],\n"
' "value_mse": [\n'
' "0.421601",\n'
' "0.67528",\n'
' "0.671502"\n'
" ]\n"
"}"
)

with open("metrics.txt", "w") as fobj:
fobj.write("ROC_AUC: 0.64\nKS: 78.9999999996\nF_SCORE: 77\n")

self.dvc.run(
fname="testing_metrics_output.dvc",
metrics_no_cache=[
"metrics.csv",
"metrics.tsv",
"metrics.json",
"metrics.txt",
],
)

self.dvc.metrics.modify("metrics.csv", typ="csv")
self.dvc.metrics.modify("metrics.tsv", typ="tsv")
self.dvc.metrics.modify("metrics.json", typ="json")

with MockLoggerHandlers(logger.logger):
reset_logger_standard_output()

ret = main(["metrics", "show"])
self.assertEqual(ret, 0)

expected_csv = (
u"\tmetrics.csv:\n"
u"\t\tvalor_mse desviación_mse data_set \n"
u"\t\t0.421601 0.173461 entrenamiento \n"
u"\t\t0.67528 0.289545 pruebas \n"
u"\t\t0.671502 0.297848 validación"
)

expected_tsv = (
"\tmetrics.tsv:\n"
"\t\tvalue_mse deviation_mse data_set \n"
"\t\t0.421601 0.173461 train \n"
"\t\t0.67528 0.289545 test\\ning \n"
"\t\t0.671502 0.297848 validation"
)

expected_txt = (
"\tmetrics.txt:\n"
"\t\tROC_AUC: 0.64\n"
"\t\tKS: 78.9999999996\n"
"\t\tF_SCORE: 77\n"
)

expected_json = (
"\tmetrics.json:\n"
"\t\t{\n"
'\t\t "data_set": [\n'
'\t\t "train",\n'
'\t\t "testing",\n'
'\t\t "validation"\n'
"\t\t ],\n"
'\t\t "deviation_mse": [\n'
'\t\t "0.173461",\n'
'\t\t "0.289545",\n'
'\t\t "0.297848"\n'
"\t\t ],\n"
'\t\t "value_mse": [\n'
'\t\t "0.421601",\n'
'\t\t "0.67528",\n'
'\t\t "0.671502"\n'
"\t\t ]\n"
"\t\t}"
)

stdout = logger.logger.handlers[0].stream.getvalue()
self.assertIn(expected_tsv, stdout)
self.assertIn(expected_csv, stdout)
self.assertIn(expected_txt, stdout)
self.assertIn(expected_json, stdout)


class TestMetricsRecursive(TestDvc):
def setUp(self):
Expand Down

0 comments on commit b3addbd

Please sign in to comment.