From 22cc2a5cbc8f9e71c81b300eec13525b006b79e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Redzy=C5=84ski?= Date: Mon, 25 Oct 2021 21:41:13 +0200 Subject: [PATCH] plots: fix resolve template path bug Fix: #6854 --- dvc/repo/plots/template.py | 10 +++++----- tests/unit/render/test_vega.py | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/dvc/repo/plots/template.py b/dvc/repo/plots/template.py index 1e1384dd9f..4e8cab56fc 100644 --- a/dvc/repo/plots/template.py +++ b/dvc/repo/plots/template.py @@ -109,16 +109,16 @@ def templates_dir(self): @staticmethod def _find(templates, template_name): for template in templates: - if ( - template_name == template - or template_name + ".json" == template + if template.endswith(template_name) or template.endswith( + template_name + ".json" ): return template return None def _find_in_project(self, name: str) -> Optional["StrPath"]: - if os.path.exists(name): - return name + full_path = os.path.abspath(name) + if os.path.exists(full_path): + return full_path if os.path.exists(self.templates_dir): templates = [ diff --git a/tests/unit/render/test_vega.py b/tests/unit/render/test_vega.py index 0b3d936c67..5ff80b55c7 100644 --- a/tests/unit/render/test_vega.py +++ b/tests/unit/render/test_vega.py @@ -451,3 +451,29 @@ def test_find_vega(tmp_dir, dvc): first(plot_content["layer"])["encoding"]["x"]["field"] == INDEX_FIELD ) assert first(plot_content["layer"])["encoding"]["y"]["field"] == "y" + + +@pytest.mark.parametrize( + "template_path, target_name", + [ + (os.path.join(".dvc", "plots", "template.json"), "template"), + (os.path.join(".dvc", "plots", "template.json"), "template.json"), + ( + os.path.join(".dvc", "plots", "subdir", "template.json"), + os.path.join("subdir", "template.json"), + ), + ( + os.path.join(".dvc", "plots", "subdir", "template.json"), + os.path.join("subdir", "template"), + ), + ("template.json", "template.json"), + ], +) +def test_should_resolve_template(tmp_dir, dvc, template_path, target_name): + os.makedirs(os.path.abspath(os.path.dirname(template_path)), exist_ok=True) + with open(template_path, "w", encoding="utf-8") as fd: + fd.write("template_content") + + assert dvc.plots.templates._find_in_project( + target_name + ) == os.path.abspath(template_path)