From 298770efe98c3457bd0be834c258ffb89f4b18d5 Mon Sep 17 00:00:00 2001 From: "Mr. Outis" <mroutis@protonmail.com> Date: Wed, 27 Mar 2019 12:26:17 -0700 Subject: [PATCH] metrics: shows formatted multiline metrics Close #1716 --- dvc/command/metrics.py | 18 ++++++- dvc/repo/metrics/show.py | 96 ++++++++++++++++++++++++++++++++- dvc/utils/compat.py | 40 ++++++++++++-- tests/test_metrics.py | 114 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 259 insertions(+), 9 deletions(-) diff --git a/dvc/command/metrics.py b/dvc/command/metrics.py index 6da22b7c8f..521d6cc13f 100644 --- a/dvc/command/metrics.py +++ b/dvc/command/metrics.py @@ -6,12 +6,26 @@ def show_metrics(metrics, all_branches=False, all_tags=False): + """ + Args: + metrics (list): Where each element is either a `list` + if an xpath was specified, otherwise a `str` + """ for branch, val in metrics.items(): if all_branches or all_tags: - logger.info("{}:".format(branch)) + logger.info("{branch}:".format(branch=branch)) for fname, metric in val.items(): - logger.info("\t{}: {}".format(fname, metric)) + lines = metric if type(metric) is list else metric.splitlines() + + if len(lines) > 1: + logger.info("\t{fname}:".format(fname=fname)) + + for line in lines: + logger.info("\t\t{content}".format(content=line)) + + else: + logger.info("\t{}: {}".format(fname, metric)) class CmdMetricsShow(CmdBase): diff --git a/dvc/repo/metrics/show.py b/dvc/repo/metrics/show.py index c4667bf5a7..e68ecee1a8 100644 --- a/dvc/repo/metrics/show.py +++ b/dvc/repo/metrics/show.py @@ -7,7 +7,7 @@ import dvc.logger as logger from dvc.exceptions import OutputNotFoundError, BadMetricError, NoMetricsError -from dvc.utils.compat import builtin_str, open +from dvc.utils.compat import builtin_str, open, StringIO, csv_reader def _read_metric_json(fd, json_path): @@ -66,13 +66,73 @@ def _read_typed_metric(typ, xpath, fd): return ret +def _format_csv(content, delimiter): + """Format delimited text to have same column width. + + Args: + content (str): The content of a metric. + delimiter (str): Value separator + + Returns: + str: Formatted content. + + Example: + + >>> content = ( + "value_mse,deviation_mse,data_set\n" + "0.421601,0.173461,train\n" + "0.67528,0.289545,testing\n" + "0.671502,0.297848,validation\n" + ) + >>> _format_csv(content, ",") + + "value_mse deviation_mse data_set\n" + "0.421601 0.173461 train\n" + "0.67528 0.289545 testing\n" + "0.671502 0.297848 validation\n" + """ + reader = csv_reader(StringIO(content), delimiter=builtin_str(delimiter)) + rows = [row for row in reader] + max_widths = [max(map(len, column)) for column in zip(*rows)] + + lines = [ + " ".join( + "{entry:{width}}".format(entry=entry, width=width + 2) + for entry, width in zip(row, max_widths) + ) + for row in rows + ] + + return "\n".join(lines) + + +def _format_output(content, typ): + """Tabularize the content according to its type. + + Args: + content (str): The content of a metric. + typ (str): The type of metric -- (raw|json|tsv|htsv|csv|hcsv). + + Returns: + str: Content in a raw or tabular format. + """ + + if "csv" in str(typ): + return _format_csv(content, delimiter=",") + + if "tsv" in str(typ): + return _format_csv(content, delimiter="\t") + + return content + + def _read_metric(fd, typ=None, xpath=None, rel_path=None, branch=None): typ = typ.lower().strip() if typ else typ try: if xpath: return _read_typed_metric(typ, xpath.strip(), fd) else: - return fd.read().strip() + return _format_output(fd.read().strip(), typ) # Json path library has to be replaced or wrapped in # order to fix this too broad except clause. except Exception: @@ -86,6 +146,23 @@ def _read_metric(fd, typ=None, xpath=None, rel_path=None, branch=None): def _collect_metrics(self, path, recursive, typ, xpath, branch): + """Gather all the metric outputs. + + Args: + path (str): Path to a metric file or a directory. + recursive (bool): If path is a directory, do a recursive search for + metrics on the given path. + typ (str): The type of metric to search for, could be one of the + following (raw|json|tsv|htsv|csv|hcsv). + xpath (str): Path to search for. + branch (str): Branch to look up for metrics. + + Returns: + list(tuple): (output, typ, xpath) + - output: + - typ: + - xpath: + """ outs = [out for stage in self.stages() for out in stage.outs] if path: @@ -126,6 +203,21 @@ def _read_metrics_filesystem(path, typ, xpath, rel_path, branch): def _read_metrics(self, metrics, branch): + """Read the content of each metric file and format it. + + Args: + metrics (list): List of metric touples + branch (str): Branch to look up for metrics. + + Returns: + A dict mapping keys with metrics path name and content. + For example: + + {'metric.csv': ("value_mse deviation_mse data_set\n" + "0.421601 0.173461 train\n" + "0.67528 0.289545 testing\n" + "0.671502 0.297848 validation\n")} + """ res = {} for out, typ, xpath in metrics: assert out.scheme == "local" diff --git a/dvc/utils/compat.py b/dvc/utils/compat.py index 15d5699c44..bf0803427d 100644 --- a/dvc/utils/compat.py +++ b/dvc/utils/compat.py @@ -27,18 +27,48 @@ def encode(u, encoding=None): return u.encode(encoding, "replace") -# NOTE: cast_bytes_py2 is taken from -# https://github.com/ipython/ipython_genutils +def csv_reader(unicode_csv_data, dialect=None, **kwargs): + """csv.reader doesn't support Unicode input, so need to use some tricks + to work around this. + + Source: https://docs.python.org/2/library/csv.html#csv-examples + """ + import csv + + dialect = dialect or csv.excel + + if is_py3: + # Python3 supports encoding by default, so just return the object + for row in csv.reader(unicode_csv_data, dialect=dialect, **kwargs): + yield [cell for cell in row] + + else: + # csv.py doesn't do Unicode; encode temporarily as UTF-8: + reader = csv.reader( + utf_8_encoder(unicode_csv_data), dialect=dialect, **kwargs + ) + for row in reader: + # decode UTF-8 back to Unicode, cell by cell: + yield [unicode(cell, "utf-8") for cell in row] # noqa: F821 + + +def utf_8_encoder(unicode_csv_data): + """Source: https://docs.python.org/2/library/csv.html#csv-examples""" + for line in unicode_csv_data: + yield line.encode("utf-8") + + def cast_bytes(s, encoding=None): + """Source: https://github.com/ipython/ipython_genutils""" if not isinstance(s, bytes): return encode(s, encoding) return s -# NOTE _makedirs is taken from -# https://github.com/python/cpython/blob/ -# 3ce3dea60646d8a5a1c952469a2eb65f937875b3/Lib/os.py#L196-L226 def _makedirs(name, mode=0o777, exist_ok=False): + """Source: https://github.com/python/cpython/blob/ + 3ce3dea60646d8a5a1c952469a2eb65f937875b3/Lib/os.py#L196-L226 + """ head, tail = os.path.split(name) if not tail: head, tail = os.path.split(head) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 523a723428..a1dfb9d050 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os import json @@ -167,6 +168,119 @@ def test_xpath_all_with_header(self): for b in ["foo", "bar", "baz"]: self.assertSequenceEqual(ret[b]["metric_hcsv"], [[b]]) + def test_formatted_output(self): + with open("metrics.csv", "w") as fobj: + # Labels are in Spanish to test unicode characters + fobj.write( + "valor_mse,desviación_mse,data_set\n" + "0.421601,0.173461,entrenamiento\n" + "0.67528,0.289545,pruebas\n" + "0.671502,0.297848,validación\n" + ) + + with open("metrics.tsv", "w") as fobj: + # Contains quoted newlines to test output correctness + fobj.write( + "value_mse\tdeviation_mse\tdata_set\n" + "0.421601\t0.173461\ttrain\n" + '0.67528\t0.289545\t"test\\ning"\n' + "0.671502\t0.297848\tvalidation\n" + ) + + with open("metrics.json", "w") as fobj: + fobj.write( + "{\n" + ' "data_set": [\n' + ' "train",\n' + ' "testing",\n' + ' "validation"\n' + " ],\n" + ' "deviation_mse": [\n' + ' "0.173461",\n' + ' "0.289545",\n' + ' "0.297848"\n' + " ],\n" + ' "value_mse": [\n' + ' "0.421601",\n' + ' "0.67528",\n' + ' "0.671502"\n' + " ]\n" + "}" + ) + + with open("metrics.txt", "w") as fobj: + fobj.write("ROC_AUC: 0.64\nKS: 78.9999999996\nF_SCORE: 77\n") + + self.dvc.run( + fname="testing_metrics_output.dvc", + metrics_no_cache=[ + "metrics.csv", + "metrics.tsv", + "metrics.json", + "metrics.txt", + ], + ) + + self.dvc.metrics.modify("metrics.csv", typ="csv") + self.dvc.metrics.modify("metrics.tsv", typ="tsv") + self.dvc.metrics.modify("metrics.json", typ="json") + + with MockLoggerHandlers(logger.logger): + reset_logger_standard_output() + + ret = main(["metrics", "show"]) + self.assertEqual(ret, 0) + + expected_csv = ( + u"\tmetrics.csv:\n" + u"\t\tvalor_mse desviación_mse data_set \n" + u"\t\t0.421601 0.173461 entrenamiento \n" + u"\t\t0.67528 0.289545 pruebas \n" + u"\t\t0.671502 0.297848 validación" + ) + + expected_tsv = ( + "\tmetrics.tsv:\n" + "\t\tvalue_mse deviation_mse data_set \n" + "\t\t0.421601 0.173461 train \n" + "\t\t0.67528 0.289545 test\\ning \n" + "\t\t0.671502 0.297848 validation" + ) + + expected_txt = ( + "\tmetrics.txt:\n" + "\t\tROC_AUC: 0.64\n" + "\t\tKS: 78.9999999996\n" + "\t\tF_SCORE: 77\n" + ) + + expected_json = ( + "\tmetrics.json:\n" + "\t\t{\n" + '\t\t "data_set": [\n' + '\t\t "train",\n' + '\t\t "testing",\n' + '\t\t "validation"\n' + "\t\t ],\n" + '\t\t "deviation_mse": [\n' + '\t\t "0.173461",\n' + '\t\t "0.289545",\n' + '\t\t "0.297848"\n' + "\t\t ],\n" + '\t\t "value_mse": [\n' + '\t\t "0.421601",\n' + '\t\t "0.67528",\n' + '\t\t "0.671502"\n' + "\t\t ]\n" + "\t\t}" + ) + + stdout = logger.logger.handlers[0].stream.getvalue() + self.assertIn(expected_tsv, stdout) + self.assertIn(expected_csv, stdout) + self.assertIn(expected_txt, stdout) + self.assertIn(expected_json, stdout) + class TestMetricsRecursive(TestDvc): def setUp(self):