From a3144482ba3c676ecc59c529061e0ae127364a22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Thu, 18 Nov 2021 01:31:10 +0100 Subject: [PATCH] render: vega: data processing outside of renderers Fixes: #6943 --- dvc/command/live.py | 4 +- dvc/command/plots.py | 5 +- dvc/render/base.py | 3 +- dvc/render/data.py | 194 ++++++++++++++++++++++++ dvc/render/utils.py | 23 ++- dvc/render/vega.py | 150 +++--------------- tests/unit/command/test_plots.py | 20 +-- tests/unit/render/test_data.py | 161 ++++++++++++++++++++ tests/unit/render/test_render.py | 4 +- tests/unit/render/test_vega.py | 156 +++++-------------- tests/unit/repo/plots/test_templates.py | 45 ++++++ 11 files changed, 494 insertions(+), 271 deletions(-) create mode 100644 dvc/render/data.py create mode 100644 tests/unit/render/test_data.py create mode 100644 tests/unit/repo/plots/test_templates.py diff --git a/dvc/command/live.py b/dvc/command/live.py index 2e97d9fb6d..22b070e221 100644 --- a/dvc/command/live.py +++ b/dvc/command/live.py @@ -3,7 +3,6 @@ from dvc.command import completion from dvc.command.base import CmdBase, fix_subparsers -from dvc.render.utils import match_renderers from dvc.ui import ui @@ -11,11 +10,12 @@ class CmdLive(CmdBase): UNINITIALIZED = True def _run(self, target, revs=None): + from dvc.render.utils import match_renderers, render + metrics, plots = self.repo.live.show(target=target, revs=revs) if plots: html_path = Path.cwd() / (self.args.target + "_html") - from dvc.render.utils import render renderers = match_renderers(plots, self.repo.plots.templates) index_path = render(self.repo, renderers, metrics, html_path) diff --git a/dvc/command/plots.py b/dvc/command/plots.py index 97439b533a..f0676e3ebd 100644 --- a/dvc/command/plots.py +++ b/dvc/command/plots.py @@ -7,8 +7,6 @@ from dvc.command import completion from dvc.command.base import CmdBase, append_doc_link, fix_subparsers from dvc.exceptions import DvcException -from dvc.render.utils import match_renderers, render -from dvc.render.vega import VegaRenderer from dvc.ui import ui from dvc.utils import format_link @@ -41,6 +39,9 @@ def _props(self): def run(self): from pathlib import Path + from dvc.render.utils import match_renderers, render + from dvc.render.vega import VegaRenderer + if self.args.show_vega: if not self.args.targets: logger.error("please specify a target for `--show-vega`") diff --git a/dvc/render/base.py b/dvc/render/base.py index 75f25d09c0..d321384a20 100644 --- a/dvc/render/base.py +++ b/dvc/render/base.py @@ -19,9 +19,8 @@ class Renderer(abc.ABC): REVISIONS_KEY = "revisions" TYPE_KEY = "type" - def __init__(self, data: Dict, templates=None): + def __init__(self, data: Dict, **kwargs): self.data = data - self.templates = templates from dvc.render.utils import get_files diff --git a/dvc/render/data.py b/dvc/render/data.py new file mode 100644 index 0000000000..5f45720ffb --- /dev/null +++ b/dvc/render/data.py @@ -0,0 +1,194 @@ +from copy import deepcopy +from functools import partial +from typing import Dict, List, Optional, Set, Union + +from funcy import first, project + +from dvc.exceptions import DvcException +from dvc.render.base import INDEX_FIELD, REVISION_FIELD + + +class FieldsNotFoundError(DvcException): + def __init__(self, expected_fields, found_fields): + expected_str = ", ".join(expected_fields) + found_str = ", ".join(found_fields) + super().__init__( + f"Could not find all provided fields ('{expected_str}') " + f"in data fields ('{found_str}')." + ) + + +class PlotDataStructureError(DvcException): + def __init__(self): + super().__init__( + "Plot data extraction failed. Please see " + "https://man.dvc.org/plots for supported data formats." + ) + + +def _filter_fields(datapoints: List[Dict], fields: Set) -> List[Dict]: + if not fields: + return datapoints + assert isinstance(fields, set) + + new_data = [] + for data_point in datapoints: + keys = set(data_point.keys()) + if not fields <= keys: + raise FieldsNotFoundError(fields, keys) + + new_data.append(project(data_point, fields)) + + return new_data + + +def _lists(dictionary: Dict): + for _, value in dictionary.items(): + if isinstance(value, dict): + yield from _lists(value) + elif isinstance(value, list): + yield value + + +def _find_first_list(data: Union[Dict, List], fields: Set) -> List[Dict]: + fields = fields or set() + + if not isinstance(data, dict): + return data + + for lst in _lists(data): + if ( + all(isinstance(dp, dict) for dp in lst) + # if fields is empty, it will match any set + and set(first(lst).keys()) & fields == fields + ): + return lst + + raise PlotDataStructureError() + + +def _append_index(datapoints: List[Dict]) -> List[Dict]: + if INDEX_FIELD in first(datapoints).keys(): + return datapoints + + for index, data_point in enumerate(datapoints): + data_point[INDEX_FIELD] = index + return datapoints + + +class Converter: + """ + Class that takes care of converting unspecified data blob + (Dict or List[Dict]) into datapoints (List[Dict]). + If some properties that are required by Template class are missing + ('x', 'y') it will attempt to fill in the blanks. + """ + + @staticmethod + def update(datapoints: List[Dict], update_dict: Dict): + for data_point in datapoints: + data_point.update(update_dict) + return datapoints + + def __init__(self, plot_properties: Optional[Dict] = None): + plot_properties = plot_properties or {} + self.props = deepcopy(plot_properties) + self.inferred_props: Dict = {} + + self.steps = [] + + self._infer_x() + self._infer_fields() + + self.steps.append( + ( + "find_data", + partial( + _find_first_list, + fields=self.inferred_props.get("fields", set()) + - {INDEX_FIELD}, + ), + ) + ) + + if not self.props.get("x", None): + self.steps.append(("append_index", partial(_append_index))) + + self.steps.append( + ( + "filter_fields", + partial( + _filter_fields, + fields=self.inferred_props.get("fields", set()), + ), + ) + ) + + def _infer_x(self): + if not self.props.get("x", None): + self.inferred_props["x"] = INDEX_FIELD + + def skip_step(self, name: str): + self.steps = [(_name, fn) for _name, fn in self.steps if _name != name] + + def _infer_fields(self): + fields = self.props.get("fields", set()) + if fields: + fields = { + *fields, + self.props.get("x", None), + self.props.get("y", None), + self.inferred_props.get("x", None), + } - {None} + self.inferred_props["fields"] = fields + + def _infer_y(self, datapoints: List[Dict]): + if "y" not in self.props: + data_fields = list(first(datapoints)) + skip = ( + REVISION_FIELD, + self.props.get("x", None) or self.inferred_props.get("x"), + ) + inferred_y = first( + f for f in reversed(data_fields) if f not in skip + ) + if "y" in self.inferred_props: + previous_y = self.inferred_props["y"] + if previous_y != inferred_y: + raise DvcException( + f"Inferred y ('{inferred_y}' value does not match" + f"previously matched one ('f{previous_y}')." + ) + else: + self.inferred_props["y"] = inferred_y + + def convert(self, data): + """ + Convert the data. Fill necessary fields ('x', 'y') and return both + generated datapoints and updated properties. + """ + processed = deepcopy(data) + + for _, step in self.steps: + processed = step(processed) + + self._infer_y(processed) + + return processed, {**self.props, **self.inferred_props} + + +def to_datapoints(data: Dict, props: Dict): + converter = Converter(props) + + datapoints = [] + for revision, rev_data in data.items(): + for _, file_data in rev_data.get("data", {}).items(): + if "data" in file_data: + processed, final_props = converter.convert( + file_data.get("data") + ) + + Converter.update(processed, {REVISION_FIELD: revision}) + + datapoints.extend(processed) + return datapoints, final_props diff --git a/dvc/render/utils.py b/dvc/render/utils.py index dd0e2db9bf..a5bdc0ac6d 100644 --- a/dvc/render/utils.py +++ b/dvc/render/utils.py @@ -21,14 +21,31 @@ def group_by_filename(plots_data: Dict) -> List[Dict]: return grouped +def squash_plots_properties(data: Dict) -> Dict: + resolved: Dict[str, str] = {} + for rev_data in data.values(): + for file_data in rev_data.get("data", {}).values(): + props = file_data.get("props", {}) + resolved = {**resolved, **props} + return resolved + + def match_renderers(plots_data, templates): from dvc.render import RENDERERS renderers = [] - for g in group_by_filename(plots_data): + for group in group_by_filename(plots_data): + + plot_properties = squash_plots_properties(group) + template = templates.load(plot_properties.get("template", None)) + for renderer_class in RENDERERS: - if renderer_class.matches(g): - renderers.append(renderer_class(g, templates)) + if renderer_class.matches(group): + renderers.append( + renderer_class( + group, template=template, properties=plot_properties + ) + ) return renderers diff --git a/dvc/render/vega.py b/dvc/render/vega.py index cd898bfbc3..a7ad894bc7 100644 --- a/dvc/render/vega.py +++ b/dvc/render/vega.py @@ -1,92 +1,12 @@ import json import os -from copy import copy, deepcopy -from typing import Dict, List, Optional, Union - -from funcy import first - -from dvc.exceptions import DvcException -from dvc.render.base import ( - INDEX_FIELD, - REVISION_FIELD, - BadTemplateError, - Renderer, -) -from dvc.render.utils import get_files - - -class PlotDataStructureError(DvcException): - def __init__(self): - super().__init__( - "Plot data extraction failed. Please see " - "https://man.dvc.org/plots for supported data formats." - ) - - -def _filter_fields( - datapoints: List[Dict], filename, revision, fields=None -) -> List[Dict]: - if not fields: - return datapoints - assert isinstance(fields, set) - - new_data = [] - for data_point in datapoints: - new_dp = copy(data_point) - - keys = set(data_point.keys()) - if keys & fields != fields: - raise DvcException( - "Could not find fields: '{}' for '{}' at '{}'.".format( - ", ".join(fields), filename, revision - ) - ) - - to_del = keys - fields - for key in to_del: - del new_dp[key] - new_data.append(new_dp) - return new_data - - -def _lists(dictionary): - for _, value in dictionary.items(): - if isinstance(value, dict): - yield from _lists(value) - elif isinstance(value, list): - yield value - +from copy import deepcopy +from typing import Dict, Optional -def _find_data(data: Union[Dict, List], fields=None) -> List[Dict]: - if not isinstance(data, dict): - return data - - if not fields: - # just look for first list of dicts - fields = set() - - for lst in _lists(data): - if ( - all(isinstance(dp, dict) for dp in lst) - and set(first(lst).keys()) & fields == fields - ): - return lst - raise PlotDataStructureError() - - -def _append_index(datapoints: List[Dict], append_index=False) -> List[Dict]: - if not append_index or INDEX_FIELD in first(datapoints).keys(): - return datapoints - - for index, data_point in enumerate(datapoints): - data_point[INDEX_FIELD] = index - return datapoints - - -def _append_revision(datapoints: List[Dict], revision) -> List[Dict]: - for data_point in datapoints: - data_point[REVISION_FIELD] = revision - return datapoints +from dvc.render.base import BadTemplateError, Renderer +from dvc.render.data import to_datapoints +from dvc.render.utils import get_files +from dvc.repo.plots.template import Template class VegaRenderer(Renderer): @@ -107,41 +27,16 @@ class VegaRenderer(Renderer): """ - def _squash_props(self) -> Dict: - resolved: Dict[str, str] = {} - for rev_data in self.data.values(): - for file_data in rev_data.get("data", {}).values(): - props = file_data.get("props", {}) - resolved = {**resolved, **props} - return resolved + def __init__( + self, data: Dict, template: Template, properties: Dict = None + ): + super().__init__(data) + self.properties = properties or {} + self.template = template def _revisions(self): return list(self.data.keys()) - def _datapoints(self, props: Dict): - fields = props.get("fields", set()) - if fields: - fields = {*fields, props.get("x"), props.get("y")} - {None} - - datapoints = [] - for revision, rev_data in self.data.items(): - for filename, file_data in rev_data.get("data", {}).items(): - if "data" in file_data: - tmp = deepcopy(file_data.get("data")) - tmp = _find_data(tmp, fields=fields - {INDEX_FIELD}) - tmp = _append_index( - tmp, append_index=props.get("append_index", False) - ) - tmp = _filter_fields( - tmp, - filename=filename, - revision=revision, - fields=fields, - ) - tmp = _append_revision(tmp, revision=revision) - datapoints.extend(tmp) - return datapoints - def _fill_template(self, template, datapoints, props=None): props = props or {} @@ -172,24 +67,13 @@ def _fill_template(self, template, datapoints, props=None): return content def get_filled_template(self): - props = self._squash_props() - - template = self.templates.load(props.get("template", None)) - - if not props.get("x") and template.has_anchor("x"): - props["append_index"] = True - props["x"] = INDEX_FIELD - - datapoints = self._datapoints(props) + props = self.properties + datapoints, final_props = to_datapoints(self.data, props) if datapoints: - if not props.get("y") and template.has_anchor("y"): - fields = list(first(datapoints)) - skip = (REVISION_FIELD, props.get("x")) - props["y"] = first( - f for f in reversed(fields) if f not in skip - ) - filled_template = self._fill_template(template, datapoints, props) + filled_template = self._fill_template( + self.template, datapoints, final_props + ) return filled_template return None diff --git a/tests/unit/command/test_plots.py b/tests/unit/command/test_plots.py index c581af57f1..62ef36bf0f 100644 --- a/tests/unit/command/test_plots.py +++ b/tests/unit/command/test_plots.py @@ -54,7 +54,7 @@ def test_plots_diff(dvc, mocker, plots_data): cmd = cli_args.func(cli_args) m = mocker.patch("dvc.repo.plots.diff.diff", return_value=plots_data) render_mock = mocker.patch( - "dvc.command.plots.render", return_value="html_path" + "dvc.render.utils.render", return_value="html_path" ) assert cmd.run() == 0 @@ -99,7 +99,7 @@ def test_plots_show_vega(dvc, mocker, plots_data): return_value=plots_data, ) render_mock = mocker.patch( - "dvc.command.plots.render", return_value="html_path" + "dvc.render.utils.render", return_value="html_path" ) assert cmd.run() == 0 @@ -126,10 +126,10 @@ def test_plots_diff_vega(dvc, mocker, capsys, plots_data): cmd = cli_args.func(cli_args) mocker.patch("dvc.repo.plots.diff.diff", return_value=plots_data) mocker.patch( - "dvc.command.plots.VegaRenderer.asdict", + "dvc.render.VegaRenderer.asdict", return_value={"this": "is vega json"}, ) - render_mock = mocker.patch("dvc.command.plots.render") + render_mock = mocker.patch("dvc.render.utils.render") assert cmd.run() == 0 out, _ = capsys.readouterr() @@ -155,7 +155,7 @@ def test_plots_diff_open(tmp_dir, dvc, mocker, capsys, plots_data, auto_open): mocker.patch("dvc.repo.plots.diff.diff", return_value=plots_data) index_path = tmp_dir / "dvc_plots" / "index.html" - mocker.patch("dvc.command.plots.render", return_value=index_path) + mocker.patch("dvc.render.utils.render", return_value=index_path) assert cmd.run() == 0 mocked_open.assert_called_once_with(index_path.as_uri()) @@ -177,7 +177,7 @@ def test_plots_diff_open_WSL(tmp_dir, dvc, mocker, plots_data): mocker.patch("dvc.repo.plots.diff.diff", return_value=plots_data) index_path = tmp_dir / "dvc_plots" / "index.html" - mocker.patch("dvc.command.plots.render", return_value=index_path) + mocker.patch("dvc.render.utils.render", return_value=index_path) assert cmd.run() == 0 mocked_open.assert_called_once_with(str(Path("dvc_plots") / "index.html")) @@ -252,9 +252,9 @@ def test_should_call_render(tmp_dir, mocker, capsys, plots_data, output): output = output or "dvc_plots" index_path = tmp_dir / output / "index.html" renderers = mocker.MagicMock() - mocker.patch("dvc.command.plots.match_renderers", return_value=renderers) + mocker.patch("dvc.render.utils.match_renderers", return_value=renderers) render_mock = mocker.patch( - "dvc.command.plots.render", return_value=index_path + "dvc.render.utils.render", return_value=index_path ) assert cmd.run() == 0 @@ -287,8 +287,8 @@ def test_plots_diff_json(dvc, mocker, capsys): mocker.patch("dvc.repo.plots.diff.diff", return_value=data) renderers = mocker.MagicMock() - mocker.patch("dvc.command.plots.match_renderers", return_value=renderers) - render_mock = mocker.patch("dvc.command.plots.render") + mocker.patch("dvc.render.utils.match_renderers", return_value=renderers) + render_mock = mocker.patch("dvc.render.utils.render") show_json_mock = mocker.patch("dvc.command.plots._show_json") diff --git a/tests/unit/render/test_data.py b/tests/unit/render/test_data.py new file mode 100644 index 0000000000..fc36c61dba --- /dev/null +++ b/tests/unit/render/test_data.py @@ -0,0 +1,161 @@ +from collections import OrderedDict + +import pytest + +from dvc.render.data import ( + Converter, + FieldsNotFoundError, + _filter_fields, + _find_first_list, + _lists, + to_datapoints, +) + + +def test_find_first_list_in_dict(): + m1 = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] + m2 = [{"x": 1}, {"x": 2}] + dmetric = OrderedDict([("t1", m1), ("t2", m2)]) + + assert _find_first_list(dmetric, fields=set()) == m1 + assert _find_first_list(dmetric, fields={"x"}) == m2 + + +def test_filter_fields(): + m = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] + + assert _filter_fields(m, fields=set()) == m + + expected = [{"accuracy": 1}, {"accuracy": 3}] + assert _filter_fields(m, fields={"accuracy"}) == expected + + with pytest.raises(FieldsNotFoundError): + _filter_fields(m, fields={"bad_field"}) + + +@pytest.mark.parametrize( + "dictionary, expected_result", + [ + ({}, []), + ({"x": ["a", "b", "c"]}, [["a", "b", "c"]]), + ( + OrderedDict([("x", {"y": ["a", "b"]}), ("z", {"w": ["c", "d"]})]), + [["a", "b"], ["c", "d"]], + ), + ], +) +def test_finding_lists(dictionary, expected_result): + result = _lists(dictionary) + + assert list(result) == expected_result + + +@pytest.mark.parametrize( + "input_data,properties,expected_datapoints,expected_properties", + [ + ( + # default x and y + {"metric": [{"v": 1}, {"v": 2}]}, + {}, + [{"v": 1, "step": 0}, {"v": 2, "step": 1}], + {"x": "step", "y": "v"}, + ), + ( + # filter fields + {"metric": [{"v": 1, "v2": 0.1}, {"v": 2, "v2": 0.2}]}, + {"fields": {"v"}}, + [{"v": 1, "step": 0}, {"v": 2, "step": 1}], + { + "x": "step", + "y": "v", + "fields": {"v", "step"}, + }, + ), + ( + # choose x and y + {"metric": [{"v": 1, "v2": 0.1}, {"v": 2, "v2": 0.2}]}, + {"x": "v", "y": "v2"}, + [{"v": 1, "v2": 0.1}, {"v": 2, "v2": 0.2}], + {"x": "v", "y": "v2"}, + ), + ( + # append x and y to filtered fields + { + "metric": [ + {"v": 1, "v2": 0.1, "v3": 0.01, "v4": 0.001}, + {"v": 2, "v2": 0.2, "v3": 0.02, "v4": 0.002}, + ] + }, + {"x": "v3", "y": "v4", "fields": {"v"}}, + [ + {"v": 1, "v3": 0.01, "v4": 0.001}, + {"v": 2, "v3": 0.02, "v4": 0.002}, + ], + {"x": "v3", "y": "v4", "fields": {"v", "v3", "v4"}}, + ), + ( + # find metric in nested structure + { + "some": "noise", + "very": { + "nested": { + "metric": [{"v": 1, "v2": 0.1}, {"v": 2, "v2": 0.2}] + } + }, + }, + {"x": "v", "y": "v2"}, + [{"v": 1, "v2": 0.1}, {"v": 2, "v2": 0.2}], + {"x": "v", "y": "v2"}, + ), + ], +) +def test_convert( + input_data, properties, expected_datapoints, expected_properties +): + converter = Converter(properties) + datapoints, resolved_properties = converter.convert(input_data) + + assert datapoints == expected_datapoints + assert resolved_properties == expected_properties + + +def test_convert_skip_step(): + converter = Converter() + converter.skip_step("append_index") + + datapoints, resolved_properties = converter.convert( + {"a": "b", "metric": [{"v": 1}, {"v": 2}]} + ) + + assert datapoints == [{"v": 1}, {"v": 2}] + assert resolved_properties == {"x": "step", "y": "v"} + + +def test_to_datapoints(): + input_data = { + "revision": { + "data": { + "filename": { + "data": { + "metric": [ + {"v": 1, "v2": 0.1, "v3": 0.01, "v4": 0.001}, + {"v": 2, "v2": 0.2, "v3": 0.02, "v4": 0.002}, + ] + } + } + } + } + } + props = {"fields": {"v"}, "x": "v2", "y": "v3"} + + datapoints, resolved_properties = to_datapoints(input_data, props) + + assert datapoints == [ + {"v": 1, "v2": 0.1, "v3": 0.01, "rev": "revision"}, + {"v": 2, "v2": 0.2, "v3": 0.02, "rev": "revision"}, + ] + assert resolved_properties == { + "fields": {"v", "v2", "v3"}, + "x": "v2", + "y": "v3", + } diff --git a/tests/unit/render/test_render.py b/tests/unit/render/test_render.py index 9003fe5a65..cb3a45b427 100644 --- a/tests/unit/render/test_render.py +++ b/tests/unit/render/test_render.py @@ -69,7 +69,9 @@ def clean(txt: str) -> str: def get_vega_string(data, filename): file_data = dpath.util.search(data, ["*", "*", filename]) - return VegaRenderer(file_data, dvc.plots.templates).partial_html() + return VegaRenderer( + file_data, dvc.plots.templates.load() + ).partial_html() index_content = index_path.read_text() assert clean(get_vega_string(data, "file.json")) in clean(index_content) diff --git a/tests/unit/render/test_vega.py b/tests/unit/render/test_vega.py index 1b5eda96bc..324f20792e 100644 --- a/tests/unit/render/test_vega.py +++ b/tests/unit/render/test_vega.py @@ -5,42 +5,11 @@ import pytest from funcy import first +from dvc.render.base import BadTemplateError +from dvc.render.data import INDEX_FIELD, REVISION_FIELD from dvc.render.utils import group_by_filename -from dvc.render.vega import ( - INDEX_FIELD, - REVISION_FIELD, - BadTemplateError, - VegaRenderer, - _find_data, - _lists, -) -from dvc.repo.plots.template import NoFieldInDataError, TemplateNotFoundError - - -@pytest.mark.parametrize( - "dictionary, expected_result", - [ - ({}, []), - ({"x": ["a", "b", "c"]}, [["a", "b", "c"]]), - ( - OrderedDict([("x", {"y": ["a", "b"]}), ("z", {"w": ["c", "d"]})]), - [["a", "b"], ["c", "d"]], - ), - ], -) -def test_finding_lists(dictionary, expected_result): - result = _lists(dictionary) - - assert list(result) == expected_result - - -def test_find_data_in_dict(tmp_dir): - m1 = [{"accuracy": 1, "loss": 2}, {"accuracy": 3, "loss": 4}] - m2 = [{"x": 1}, {"x": 2}] - dmetric = OrderedDict([("t1", m1), ("t2", m2)]) - - assert _find_data(dmetric) == m1 - assert _find_data(dmetric, fields={"x"}) == m2 +from dvc.render.vega import VegaRenderer +from dvc.repo.plots.template import NoFieldInDataError def test_group_plots_data(): @@ -93,7 +62,7 @@ def test_group_plots_data(): } in results -def test_one_column(tmp_dir, scm, dvc): +def test_one_column(tmp_dir, dvc): props = { "x_label": "x_title", "y_label": "y_title", @@ -101,13 +70,13 @@ def test_one_column(tmp_dir, scm, dvc): } data = { "workspace": { - "data": { - "file.json": {"data": [{"val": 2}, {"val": 3}], "props": props} - } + "data": {"file.json": {"data": [{"val": 2}, {"val": 3}]}} } } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer( + data, template=dvc.plots.templates.load(), properties=props + ).asdict() assert plot_content["title"] == "mytitle" assert plot_content["data"]["values"] == [ @@ -132,7 +101,7 @@ def test_multiple_columns(tmp_dir, scm, dvc): "workspace": {"data": {"file.json": {"data": metric, "props": {}}}} } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer(data, dvc.plots.templates.load()).asdict() assert plot_content["data"]["values"] == [ { @@ -167,7 +136,9 @@ def test_choose_axes(tmp_dir, scm, dvc): data = { "workspace": {"data": {"file.json": {"data": metric, "props": props}}} } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer( + data, template=dvc.plots.templates.load(), properties=props + ).asdict() assert plot_content["data"]["values"] == [ { @@ -198,13 +169,11 @@ def test_confusion(tmp_dir, dvc): ] props = {"template": "confusion", "x": "predicted", "y": "actual"} - data = { - "workspace": { - "data": {"file.json": {"data": confusion_matrix, "props": props}} - } - } + data = {"workspace": {"data": {"file.json": {"data": confusion_matrix}}}} - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer( + data, template=dvc.plots.templates.load("confusion"), properties=props + ).asdict() assert plot_content["data"]["values"] == [ {"predicted": "B", "actual": "A", REVISION_FIELD: "workspace"}, @@ -241,7 +210,7 @@ def test_multiple_revs_default(tmp_dir, scm, dvc): }, } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer(data, dvc.plots.templates.load()).asdict() assert plot_content["data"]["values"] == [ {"y": 5, INDEX_FIELD: 0, REVISION_FIELD: "HEAD"}, @@ -266,7 +235,7 @@ def test_metric_missing(tmp_dir, scm, dvc, caplog): "data": {"file.json": {"error": FileNotFoundError(), "props": {}}} }, } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer(data, dvc.plots.templates.load()).asdict() assert plot_content["data"]["values"] == [ {"y": 2, INDEX_FIELD: 0, REVISION_FIELD: "v2"}, @@ -278,50 +247,19 @@ def test_metric_missing(tmp_dir, scm, dvc, caplog): assert first(plot_content["layer"])["encoding"]["y"]["field"] == "y" -def test_custom_template(tmp_dir, scm, dvc, custom_template): - metric = [{"a": 1, "b": 2}, {"a": 2, "b": 3}] - props = {"template": os.fspath(custom_template), "x": "a", "y": "b"} - data = { - "workspace": {"data": {"file.json": {"data": metric, "props": props}}} - } - - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() - - assert plot_content["data"]["values"] == [ - {"a": 1, "b": 2, REVISION_FIELD: "workspace"}, - {"a": 2, "b": 3, REVISION_FIELD: "workspace"}, - ] - assert plot_content["encoding"]["x"]["field"] == "a" - assert plot_content["encoding"]["y"]["field"] == "b" - - -def test_raise_on_no_template(tmp_dir, dvc): - metric = [{"val": 2}, {"val": 3}] - props = {"template": "non_existing_template.json"} - data = { - "workspace": {"data": {"file.json": {"data": metric, "props": props}}} - } - - with pytest.raises(TemplateNotFoundError): - VegaRenderer(data, dvc.plots.templates).asdict() - - def test_bad_template(tmp_dir, dvc): metric = [{"val": 2}, {"val": 3}] - (tmp_dir / "template.json").dump({"a": "b", "c": "d"}) - props = {"template": "template.json"} - data = { - "workspace": {"data": {"file.json": {"data": metric, "props": props}}} - } + data = {"workspace": {"data": {"file.json": {"data": metric}}}} + + from dvc.repo.plots.template import Template with pytest.raises(BadTemplateError): - VegaRenderer(data, dvc.plots.templates).asdict() + VegaRenderer(data, Template("name", "content")).asdict() def test_plot_choose_columns(tmp_dir, scm, dvc, custom_template): metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}] props = { - "template": os.fspath(custom_template), "fields": {"b", "c"}, "x": "b", "y": "c", @@ -330,7 +268,11 @@ def test_plot_choose_columns(tmp_dir, scm, dvc, custom_template): "workspace": {"data": {"file.json": {"data": metric, "props": props}}} } - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() + plot_content = VegaRenderer( + data, + template=dvc.plots.templates.load(os.fspath(custom_template)), + properties=props, + ).asdict() assert plot_content["data"]["values"] == [ {"b": 2, "c": 3, REVISION_FIELD: "workspace"}, @@ -340,36 +282,15 @@ def test_plot_choose_columns(tmp_dir, scm, dvc, custom_template): assert plot_content["encoding"]["y"]["field"] == "c" -def test_plot_default_choose_column(tmp_dir, scm, dvc): - metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}] - data = { - "workspace": { - "data": {"file.json": {"data": metric, "props": {"fields": {"b"}}}} - } - } - - plot_content = VegaRenderer(data, dvc.plots.templates).asdict() - - assert plot_content["data"]["values"] == [ - {INDEX_FIELD: 0, "b": 2, REVISION_FIELD: "workspace"}, - {INDEX_FIELD: 1, "b": 3, REVISION_FIELD: "workspace"}, - ] - assert ( - first(plot_content["layer"])["encoding"]["x"]["field"] == INDEX_FIELD - ) - assert first(plot_content["layer"])["encoding"]["y"]["field"] == "b" - - def test_raise_on_wrong_field(tmp_dir, scm, dvc): metric = [{"val": 2}, {"val": 3}] - data = { - "workspace": { - "data": {"file.json": {"data": metric, "props": {"x": "no_val"}}} - } - } + props = {"x": "no_val"} + data = {"workspace": {"data": {"file.json": {"data": metric}}}} with pytest.raises(NoFieldInDataError): - VegaRenderer(data, dvc.plots.templates).asdict() + VegaRenderer( + data, template=dvc.plots.templates.load(), properties=props + ).asdict() @pytest.mark.parametrize( @@ -422,13 +343,12 @@ def test_should_resolve_template(tmp_dir, dvc, template_path, target_name): def test_as_json(tmp_dir, scm, dvc): metric = [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 3, "c": 4}] - data = { - "workspace": { - "data": {"file.json": {"data": metric, "props": {"fields": {"b"}}}} - } - } + data = {"workspace": {"data": {"file.json": {"data": metric}}}} + props = {"fields": {"b"}} - renderer = VegaRenderer(data, dvc.plots.templates) + renderer = VegaRenderer( + data, template=dvc.plots.templates.load(), properties=props + ) plot_content = renderer.asdict() plot_as_json = first(json.loads(renderer.as_json())) diff --git a/tests/unit/repo/plots/test_templates.py b/tests/unit/repo/plots/test_templates.py new file mode 100644 index 0000000000..6a54dfa688 --- /dev/null +++ b/tests/unit/repo/plots/test_templates.py @@ -0,0 +1,45 @@ +import os + +import pytest + +from dvc.repo.plots.template import TemplateNotFoundError + + +def test_raise_on_no_template(tmp_dir, dvc): + with pytest.raises(TemplateNotFoundError): + dvc.plots.templates.load("non_existing_template.json") + + +@pytest.mark.parametrize( + "template_path, target_name", + [ + (os.path.join(".dvc", "plots", "template.json"), "template"), + (os.path.join(".dvc", "plots", "template.json"), "template.json"), + ( + os.path.join(".dvc", "plots", "subdir", "template.json"), + os.path.join("subdir", "template.json"), + ), + ( + os.path.join(".dvc", "plots", "subdir", "template.json"), + os.path.join("subdir", "template"), + ), + ("template.json", "template.json"), + ], +) +def test_load_template(tmp_dir, dvc, template_path, target_name): + os.makedirs(os.path.abspath(os.path.dirname(template_path)), exist_ok=True) + with open(template_path, "w", encoding="utf-8") as fd: + fd.write("template_content") + + assert dvc.plots.templates.load(target_name).content == "template_content" + + +def test_load_default_template(tmp_dir, dvc): + with open( + os.path.join(dvc.plots.templates.templates_dir, "linear.json"), + "r", + encoding="utf-8", + ) as fd: + content = fd.read() + + assert dvc.plots.templates.load(None).content == content