Skip to content

Commit

Permalink
render: Extract to dvc_render
Browse files Browse the repository at this point in the history
Create and use separate package for rendering logic.
Closes iterative#6944
  • Loading branch information
daavoo committed Mar 28, 2022
1 parent b71a0f3 commit ddd6c97
Show file tree
Hide file tree
Showing 43 changed files with 755 additions and 2,726 deletions.
5 changes: 3 additions & 2 deletions dvc/commands/experiments/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,15 @@ def show_experiments(
for x in td.column("Experiment")
]
out = kwargs.get("out") or "dvc_plots"
output_file = os.path.join(out, "index.html")
ui.write(
td.to_parallel_coordinates(
output_path=os.path.abspath(out),
output_path=os.path.abspath(output_file),
color_by=kwargs.get("sort_by") or "Experiment",
)
)
if kwargs.get("open"):
return ui.open_browser(os.path.join(out, "index.html"))
return ui.open_browser(output_file)

else:
td.render(
Expand Down
12 changes: 8 additions & 4 deletions dvc/commands/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@ class CmdLive(CmdBase):
UNINITIALIZED = True

def _run(self, target, revs=None):
from dvc.render.utils import match_renderers, render
from dvc_render import render_html

from dvc.render.match import match_renderers

metrics, plots = self.repo.live.show(target=target, revs=revs)

if plots:
from pathlib import Path

html_path = Path.cwd() / (self.args.target + "_html")
output = Path.cwd() / (self.args.target + "_html") / "index.html"

renderers = match_renderers(plots, self.repo.plots.templates)
index_path = render(self.repo, renderers, metrics, html_path)
renderers = match_renderers(
plots, templates_dir=self.repo.plots.templates_dir
)
index_path = render_html(renderers, output, metrics)
ui.write(index_path.as_uri())
return 0
return 1
Expand Down
61 changes: 34 additions & 27 deletions dvc/commands/plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import json
import logging
import os

from funcy import first

Expand All @@ -14,15 +15,11 @@
logger = logging.getLogger(__name__)


def _show_json(renderers, path: None, split=False):
if any(r.needs_output_path for r in renderers) and not path:
raise DvcException("Output path ('-o') is required!")
def _show_json(renderers, split=False):
from dvc.render.convert import to_json

result = {
renderer.filename: json.loads(
renderer.as_json(path=path, fill_data=not split)
)
for renderer in renderers
renderer.name: to_json(renderer, split) for renderer in renderers
}
if result:
ui.write_json(result)
Expand All @@ -42,8 +39,9 @@ def _props(self):
def run(self):
from pathlib import Path

from dvc.render.utils import match_renderers, render
from dvc.render.vega import VegaRenderer
from dvc_render import render_html

from dvc.render.match import match_renderers

if self.args.show_vega:
if not self.args.targets:
Expand Down Expand Up @@ -74,31 +72,40 @@ def run(self):
)

renderers = match_renderers(
plots_data=plots_data, templates=self.repo.plots.templates
plots_data=plots_data, out=self.args.out
)

if self.args.show_vega:
renderer = first(
filter(lambda r: isinstance(r, VegaRenderer), renderers)
)
renderer = first(filter(lambda r: r.TYPE == "vega", renderers))
if renderer:
content = renderer.asdict()
ui.write_json(content)
ui.write_json(json.loads(renderer.partial_html()))
return 0
if self.args.json:
_show_json(renderers, self.args.out, self.args.split)
_show_json(renderers, self.args.split)
return 0

html_template_path = self.args.html_template
if not html_template_path:
html_template_path = self.repo.config.get("plots", {}).get(
"html_template", None
)
if html_template_path and not os.path.isabs(
html_template_path
):
html_template_path = os.path.join(
self.repo.dvc_dir, html_template_path
)

rel: str = self.args.out or "dvc_plots"
path: Path = (Path.cwd() / rel).resolve()
index_path = render(
self.repo,
renderers,
path=path,
html_template_path=self.args.html_template,
output_file: Path = (Path.cwd() / rel).resolve() / "index.html"

render_html(
renderers=renderers,
output_file=output_file,
template_path=html_template_path,
)

ui.write(index_path.as_uri())
ui.write(output_file.as_uri())
auto_open = self.repo.config["plots"].get("auto_open", False)
if self.args.open or auto_open:
if not auto_open:
Expand All @@ -107,7 +114,7 @@ def run(self):
"\n"
"\tdvc config plots.auto_open true"
)
return ui.open_browser(index_path)
return ui.open_browser(output_file)

return 0

Expand Down Expand Up @@ -154,17 +161,17 @@ class CmdPlotsTemplates(CmdBase):
]

def run(self):
import os
from dvc_render.vega_templates import dump_templates

try:
out = (
os.path.join(os.getcwd(), self.args.out)
if self.args.out
else self.repo.plots.templates.templates_dir
else self.repo.plots.templates_dir
)

targets = [self.args.target] if self.args.target else None
self.repo.plots.templates.init(output=out, targets=targets)
dump_templates(output=out, targets=targets)
templates_path = os.path.relpath(out, os.getcwd())
ui.write(f"Templates have been written into '{templates_path}'.")

Expand Down
18 changes: 11 additions & 7 deletions dvc/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,21 @@ def to_csv(self) -> str:

def to_parallel_coordinates(
self, output_path: "StrPath", color_by: str = None
) -> str:
from dvc.render.html import write
from dvc.render.plotly import ParallelCoordinatesRenderer
) -> "StrPath":
from dvc_render.html import render_html
from dvc_render.plotly import ParallelCoordinatesRenderer

index_path = write(
output_path,
render_html(
renderers=[
ParallelCoordinatesRenderer(self, color_by, self._fill_value)
ParallelCoordinatesRenderer(
self.as_dict(),
color_by=color_by,
fill_value=self._fill_value,
)
],
output_file=output_path,
)
return index_path.as_uri()
return output_path

def add_column(self, name: str) -> None:
self._columns[name] = Column([self._fill_value] * len(self))
Expand Down
9 changes: 5 additions & 4 deletions dvc/render/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .image import ImageRenderer
from .vega import VegaRenderer

RENDERERS = [ImageRenderer, VegaRenderer]
INDEX_FIELD = "step"
REVISION_FIELD = "rev"
REVISIONS_KEY = "revisions"
TYPE_KEY = "type"
SRC_FIELD = "src"
77 changes: 0 additions & 77 deletions dvc/render/base.py

This file was deleted.

68 changes: 68 additions & 0 deletions dvc/render/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
from collections import defaultdict
from typing import Dict, List, Union

from dvc.render import REVISION_FIELD, REVISIONS_KEY, SRC_FIELD, TYPE_KEY
from dvc.render.image_converter import ImageConverter
from dvc.render.vega_converter import VegaConverter


def get_converter(
renderer_class, props
) -> Union[VegaConverter, ImageConverter]:
from dvc_render import ImageRenderer, VegaRenderer

if renderer_class.TYPE == VegaRenderer.TYPE:
return VegaConverter(props)
if renderer_class.TYPE == ImageRenderer.TYPE:
return ImageConverter(props)

raise ValueError(f"Invalid renderer class {renderer_class}")


def to_datapoints(renderer_class, data: Dict, props: Dict):
converter = get_converter(renderer_class, props)
datapoints: List[Dict] = []
for revision, rev_data in data.items():
for filename, file_data in rev_data.get("data", {}).items():
if "data" in file_data:
processed, final_props = converter.convert(
revision, filename, file_data.get("data")
)
datapoints.extend(processed)
return datapoints, final_props


def _group_by_rev(datapoints):
grouped = defaultdict(list)
for datapoint in datapoints:
rev = datapoint.pop(REVISION_FIELD)
grouped[rev].append(datapoint)
return dict(grouped)


def to_json(renderer, split: bool = False) -> List[Dict]:
if renderer.TYPE == "vega":
grouped = _group_by_rev(renderer.datapoints)
if split:
content = renderer.get_filled_template(skip_anchors=["data"])
else:
content = renderer.get_filled_template()
return [
{
TYPE_KEY: renderer.TYPE,
REVISIONS_KEY: sorted(grouped.keys()),
"content": json.loads(content),
"datapoints": grouped,
}
]
if renderer.TYPE == "image":
return [
{
TYPE_KEY: renderer.TYPE,
REVISIONS_KEY: datapoint.get(REVISION_FIELD),
"url": datapoint.get(SRC_FIELD),
}
for datapoint in renderer.datapoints
]
raise ValueError(f"Invalid renderer: {renderer.TYPE}")
Loading

0 comments on commit ddd6c97

Please sign in to comment.