Skip to content

Commit

Permalink
plots: cleanup data extraction (iterative#6355)
Browse files Browse the repository at this point in the history
* plots: cleanup data extraction

* fixup
  • Loading branch information
pared authored Jul 23, 2021
1 parent 1d7e754 commit 67b42eb
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 89 deletions.
54 changes: 9 additions & 45 deletions dvc/repo/plots/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,10 @@ def __init__(self, path, revision):

def plot_data(filename, revision, content):
_, extension = os.path.splitext(filename.lower())
if extension == ".json":
return JSONPlotData(filename, revision, content)
if extension == ".csv":
return CSVPlotData(filename, revision, content)
if extension == ".tsv":
return CSVPlotData(filename, revision, content, delimiter="\t")
if extension == ".yaml":
return YAMLPlotData(filename, revision, content)
if extension in (".json", ".yaml"):
return DictData(filename, revision, content)
if extension in (".csv", ".tsv"):
return ListData(filename, revision, content)
raise PlotMetricTypeError(filename)


Expand Down Expand Up @@ -68,34 +64,6 @@ def _filter_fields(data_points, filename, revision, fields=None, **kwargs):
return new_data


def _apply_path(data, path=None, **kwargs):
if not path or not isinstance(data, dict):
return data

import jsonpath_ng

found = jsonpath_ng.parse(path).find(data)
first_datum = first(found)
if (
len(found) == 1
and isinstance(first_datum.value, list)
and isinstance(first(first_datum.value), dict)
):
data_points = first_datum.value
elif len(first_datum.path.fields) == 1:
field_name = first(first_datum.path.fields)
data_points = [{field_name: datum.value} for datum in found]
else:
raise PlotDataStructureError()

if not isinstance(data_points, list) or not (
isinstance(first(data_points), dict)
):
raise PlotDataStructureError()

return data_points


def _lists(dictionary):
for _, value in dictionary.items():
if isinstance(value, dict):
Expand Down Expand Up @@ -158,17 +126,13 @@ def to_datapoints(self, **kwargs):
return data


class JSONPlotData(PlotData):
class DictData(PlotData):
# For files usually parsed as dicts: eg JSON, Yaml
def _processors(self):
parent_processors = super()._processors()
return [_apply_path, _find_data] + parent_processors
return [_find_data] + parent_processors


class CSVPlotData(PlotData):
class ListData(PlotData):
# For files parsed as list: CSV, TSV
pass


class YAMLPlotData(PlotData):
def _processors(self):
parent_processors = super()._processors()
return [_find_data] + parent_processors
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def run(self):
"nanotime>=0.5.2",
"pyasn1>=0.4.1",
"voluptuous>=0.11.7",
"jsonpath-ng>=1.5.1",
"requests>=2.22.0",
"grandalf==0.6",
"distro>=1.3.0",
Expand Down
21 changes: 1 addition & 20 deletions tests/func/plots/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
from dvc.main import main
from dvc.path_info import PathInfo
from dvc.repo import Repo
from dvc.repo.plots.data import (
JSONPlotData,
PlotData,
PlotMetricTypeError,
YAMLPlotData,
)
from dvc.repo.plots.data import PlotData, PlotMetricTypeError
from dvc.repo.plots.template import (
BadTemplateError,
NoFieldInDataError,
Expand Down Expand Up @@ -560,20 +555,6 @@ def test_raise_on_wrong_field(tmp_dir, scm, dvc, run_copy_metrics):
dvc.plots.show("metric.json", props={"y": "no_val"})


@pytest.mark.parametrize("data_class", [JSONPlotData, YAMLPlotData])
def test_find_data_in_dict(tmp_dir, data_class):
metric = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}]
dmetric = {"train": metric}

plot_data = data_class("-", "revision", dmetric)

expected = metric
for d in expected:
d["rev"] = "revision"

assert list(map(dict, plot_data.to_datapoints())) == expected


def test_multiple_plots(tmp_dir, scm, dvc, run_copy_metrics):
metric1 = [
OrderedDict([("first_val", 100), ("second_val", 100), ("val", 2)]),
Expand Down
42 changes: 19 additions & 23 deletions tests/unit/repo/plots/test_data.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,9 @@
from collections import OrderedDict
from typing import Dict, List

import pytest

from dvc.repo.plots.data import _apply_path, _find_data, _lists


@pytest.mark.parametrize(
"path,expected_result",
[
("$.some.path[*].a", [{"a": 1}, {"a": 4}]),
("$.some.path", [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]),
],
)
def test_parse_json(path, expected_result):
value = {
"some": {"path": [{"a": 1, "b": 2, "c": 3}, {"a": 4, "b": 5, "c": 6}]}
}

result = _apply_path(value, path=path)

assert result == expected_result
from dvc.repo.plots.data import DictData, _lists


@pytest.mark.parametrize(
Expand All @@ -39,10 +23,22 @@ def test_finding_lists(dictionary, expected_result):
assert list(result) == expected_result


@pytest.mark.parametrize("fields", [{"x"}, set()])
def test_finding_data(fields):
data = {"a": {"b": [{"x": 2, "y": 3}, {"x": 1, "y": 5}]}}
def test_find_data_in_dict(tmp_dir):
m1 = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}]
m2 = [{"x": 1}, {"x": 2}]
dmetric = OrderedDict([("t1", m1), ("t2", m2)])

plot_data = DictData("-", "revision", dmetric)

def points_with(datapoints: List, additional_info: Dict):
for datapoint in datapoints:
datapoint.update(additional_info)

result = _find_data(data, fields=fields)
return datapoints

assert result == [{"x": 2, "y": 3}, {"x": 1, "y": 5}]
assert list(map(dict, plot_data.to_datapoints())) == points_with(
m1, {"rev": "revision"}
)
assert list(
map(dict, plot_data.to_datapoints(fields={"x"}))
) == points_with(m2, {"rev": "revision"})

0 comments on commit 67b42eb

Please sign in to comment.