Skip to content

Commit

Permalink
addressing: allow stage name addressing without ":" (iterative#3842)
Browse files Browse the repository at this point in the history
* addressing: allow stage name addressing without colon

* repo: reorganize imports

* refactor/reorganize collect_granular

* add tests for stage addressing

* fix ds and cc issues

* fix unpack error in test

* tests: stage string representations

* exception: add error message

Also remove `cause` when throwing NoOutputOrStage, as
CheckoutErrorSuggestGit will start chaining all the messages
in cause of the exception.

* rename NoOutputOrStage NoOutputOrStageError

Co-authored-by: Ruslan Kuprieiev <[email protected]>
  • Loading branch information
skshetry and efiop authored May 26, 2020
1 parent b479984 commit 2fecc66
Show file tree
Hide file tree
Showing 24 changed files with 564 additions and 173 deletions.
6 changes: 2 additions & 4 deletions dvc/command/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
class CmdPipelineShow(CmdBase):
def _show(self, target, commands, outs, locked):
import networkx
from dvc import dvcfile
from dvc.utils import parse_target

path, name = parse_target(target)
stage = dvcfile.Dvcfile(self.repo, path).stages[name]
stage = self.repo.get_stage(path, name)
G = self.repo.graph
stages = networkx.dfs_postorder_nodes(G, stage)
if locked:
Expand Down Expand Up @@ -58,12 +57,11 @@ def _build_output_graph(G, target_stage):

def _build_graph(self, target, commands=False, outs=False):
import networkx
from dvc import dvcfile
from dvc.repo.graph import get_pipeline
from dvc.utils import parse_target

path, name = parse_target(target)
target_stage = dvcfile.Dvcfile(self.repo, path).stages[name]
target_stage = self.repo.get_stage(path, name)
G = get_pipeline(self.repo.pipelines, target_stage)

nodes = set()
Expand Down
12 changes: 12 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,15 @@ def __init__(self, path_info):

class IsADirectoryError(DvcException):
"""Raised when a file operation is requested on a directory."""


class NoOutputOrStageError(DvcException):
"""
Raised when the target is neither an output nor a stage name in dvc.yaml
"""

def __init__(self, target, file):
super().__init__(
f"'{target}' "
f"does not exist as an output or a stage name in '{file}'"
)
121 changes: 95 additions & 26 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
import logging
import os
from contextlib import contextmanager
from functools import wraps

from funcy import cached_property, cat, first

from dvc.config import Config
from dvc.dvcfile import PIPELINE_FILE, Dvcfile, is_valid_filename
from dvc.exceptions import FileMissingError
from dvc.exceptions import IsADirectoryError as DvcIsADirectoryError
from dvc.exceptions import NotDvcRepoError, OutputNotFoundError
from dvc.exceptions import (
NoOutputOrStageError,
NotDvcRepoError,
OutputNotFoundError,
)
from dvc.ignore import CleanTree
from dvc.path_info import PathInfo
from dvc.repo.tree import RepoTree
from dvc.utils.fs import path_isin

from ..stage.exceptions import StageFileDoesNotExistError, StageNotFound
from ..utils import parse_target
from .graph import check_acyclic, get_pipeline, get_pipelines

logger = logging.getLogger(__name__)


def locked(f):
@wraps(f)
Expand Down Expand Up @@ -181,6 +190,25 @@ def _ignore(self):

self.scm.ignore_list(flist)

def get_stage(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

dvcfile = Dvcfile(self, path)
return dvcfile.stages[name]

def get_stages(self, path=None, name=None):
if not path:
path = PIPELINE_FILE
logger.debug("Assuming '%s' to be a stage inside '%s'", name, path)

if name:
return [self.get_stage(path, name)]

dvcfile = Dvcfile(self, path)
return list(dvcfile.stages.values())

def check_modified_graph(self, new_stages):
"""Generate graph including the new stage to check for errors"""
# Building graph might be costly for the ones with many DVC-files,
Expand All @@ -204,10 +232,9 @@ def _collect_inside(self, path, graph):
stages = nx.dfs_postorder_nodes(graph)
return [stage for stage in stages if path_isin(stage.path, path)]

def collect(self, target, with_deps=False, recursive=False, graph=None):
import networkx as nx
from ..dvcfile import Dvcfile

def collect(
self, target=None, with_deps=False, recursive=False, graph=None
):
if not target:
return list(graph) if graph else self.stages

Expand All @@ -217,36 +244,81 @@ def collect(self, target, with_deps=False, recursive=False, graph=None):
)

path, name = parse_target(target)
dvcfile = Dvcfile(self, path)
stages = list(dvcfile.stages.filter(name).values())
stages = self.get_stages(path, name)
if not with_deps:
return stages

res = set()
for stage in stages:
pipeline = get_pipeline(get_pipelines(graph or self.graph), stage)
res.update(nx.dfs_postorder_nodes(pipeline, stage))
res.update(self._collect_pipeline(stage, graph=graph))
return res

def collect_granular(self, target, *args, **kwargs):
from ..dvcfile import Dvcfile, is_valid_filename
def _collect_pipeline(self, stage, graph=None):
import networkx as nx

pipeline = get_pipeline(get_pipelines(graph or self.graph), stage)
return nx.dfs_postorder_nodes(pipeline, stage)

def _collect_from_default_dvcfile(self, target):
dvcfile = Dvcfile(self, PIPELINE_FILE)
if dvcfile.exists():
return dvcfile.stages.get(target)

def collect_granular(
self, target=None, with_deps=False, recursive=False, graph=None
):
"""
Priority is in the order of following in case of ambiguity:
- .dvc file or .yaml file
- dir if recursive and directory exists
- stage_name
- output file
"""
if not target:
return [(stage, None) for stage in self.stages]

file, name = parse_target(target)
if is_valid_filename(file) and not kwargs.get("with_deps"):
# Optimization: do not collect the graph for a specific .dvc target
stages = Dvcfile(self, file).stages.filter(name)
return [(stage, None) for stage in stages.values()]
stages = []

try:
(out,) = self.find_outs_by_path(file, strict=False)
filter_info = PathInfo(os.path.abspath(file))
return [(out.stage, filter_info)]
except OutputNotFoundError:
stages = self.collect(target, *args, **kwargs)
return [(stage, None) for stage in stages]
# Optimization: do not collect the graph for a specific target
if not file:
# parsing is ambiguous when it does not have a colon
# or if it's not a dvcfile, as it can be a stage name
# in `dvc.yaml` or, an output in a stage.
logger.debug(
"Checking if stage '%s' is in '%s'", target, PIPELINE_FILE
)
if not (recursive and os.path.isdir(target)):
stage = self._collect_from_default_dvcfile(target)
if stage:
stages = (
self._collect_pipeline(stage) if with_deps else [stage]
)
elif not with_deps and is_valid_filename(file):
stages = self.get_stages(file, name)

if not stages:
if not (recursive and os.path.isdir(target)):
try:
(out,) = self.find_outs_by_path(target, strict=False)
filter_info = PathInfo(os.path.abspath(target))
return [(out.stage, filter_info)]
except OutputNotFoundError:
pass

try:
stages = self.collect(target, with_deps, recursive, graph)
except StageFileDoesNotExistError as exc:
# collect() might try to use `target` as a stage name
# and throw error that dvc.yaml does not exist, whereas it
# should say that both stage name and file does not exist.
if file and is_valid_filename(file):
raise
raise NoOutputOrStageError(target, exc.file) from exc
except StageNotFound as exc:
raise NoOutputOrStageError(target, exc.file) from exc

return [(stage, None) for stage in stages]

def used_cache(
self,
Expand Down Expand Up @@ -443,16 +515,13 @@ def plot_templates(self):
return PlotTemplates(self.dvc_dir)

def _collect_stages(self):
from dvc.dvcfile import Dvcfile, is_valid_filename

stages = []
outs = set()

for root, dirs, files in self.tree.walk(self.root_dir):
for file_name in filter(is_valid_filename, files):
path = os.path.join(root, file_name)
stage_loader = Dvcfile(self, path).stages
stages.extend(stage_loader.values())
stages.extend(self.get_stages(path))
outs.update(
out.fspath
for stage in stages
Expand Down
22 changes: 14 additions & 8 deletions dvc/repo/checkout.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import logging
import os

from dvc.exceptions import CheckoutError, CheckoutErrorSuggestGit
from dvc.exceptions import (
CheckoutError,
CheckoutErrorSuggestGit,
NoOutputOrStageError,
)
from dvc.progress import Tqdm
from dvc.utils import relpath

Expand All @@ -18,11 +22,11 @@ def _get_unused_links(repo):
return repo.state.get_unused_links(used)


def _fspath_dir(path, root):
def _fspath_dir(path):
if not os.path.exists(str(path)):
return str(path)

path = relpath(path, root)
path = relpath(path)
return os.path.join(path, "") if os.path.isdir(path) else path


Expand Down Expand Up @@ -56,7 +60,7 @@ def _checkout(
targets = [None]
unused = _get_unused_links(self)

stats["deleted"] = [_fspath_dir(u, self.root_dir) for u in unused]
stats["deleted"] = [_fspath_dir(u) for u in unused]
self.state.remove_links(unused)

if isinstance(targets, str):
Expand All @@ -70,7 +74,11 @@ def _checkout(
target, with_deps=with_deps, recursive=recursive
)
)
except (StageFileDoesNotExistError, StageFileBadNameError) as exc:
except (
StageFileDoesNotExistError,
StageFileBadNameError,
NoOutputOrStageError,
) as exc:
if not target:
raise
raise CheckoutErrorSuggestGit(target) from exc
Expand All @@ -87,9 +95,7 @@ def _checkout(
filter_info=filter_info,
)
for key, items in result.items():
stats[key].extend(
_fspath_dir(path, self.root_dir) for path in items
)
stats[key].extend(_fspath_dir(path) for path in items)

if stats.get("failed"):
raise CheckoutError(stats["failed"], stats)
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ def commit(self, target, with_deps=False, recursive=False, force=False):
stage.commit()

Dvcfile(self, stage.path).dump(stage)
return stages
6 changes: 2 additions & 4 deletions dvc/repo/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@

@locked
def lock(self, target, unlock=False):
from .. import dvcfile
from dvc.utils import parse_target

path, name = parse_target(target)
dvcfile = dvcfile.Dvcfile(self, path)
stage = dvcfile.stages[name]
stage = self.get_stage(path, name)
stage.locked = False if unlock else True
dvcfile.dump(stage, update_pipeline=True)
stage.dvcfile.dump(stage, update_pipeline=True)

return stage
9 changes: 4 additions & 5 deletions dvc/repo/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

@locked
def remove(self, target, dvc_only=False):
from ..dvcfile import Dvcfile
from ..dvcfile import Dvcfile, is_valid_filename

path, name = parse_target(target)
dvcfile = Dvcfile(self, path)
stages = list(dvcfile.stages.filter(name).values())
stages = self.get_stages(path, name)
for stage in stages:
stage.remove_outs(force=True)

if not dvc_only:
dvcfile.remove()
if path and is_valid_filename(path) and not dvc_only:
Dvcfile(self, path).remove()

return stages
4 changes: 1 addition & 3 deletions dvc/repo/reproduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def reproduce(
all_pipelines=False,
**kwargs
):
from ..dvcfile import Dvcfile
from dvc.utils import parse_target

if not target and not all_pipelines:
Expand All @@ -81,8 +80,7 @@ def reproduce(
if all_pipelines:
pipelines = active_pipelines
else:
dvcfile = Dvcfile(self, path)
stage = dvcfile.stages[name]
stage = self.get_stage(path, name)
pipelines = [get_pipeline(active_pipelines, stage)]

targets = []
Expand Down
Loading

0 comments on commit 2fecc66

Please sign in to comment.