Skip to content

Commit

Permalink
refactor run; move to stage.utils (iterative#5262)
Browse files Browse the repository at this point in the history
* refactor run; move to stage.utils

* Update dvc/repo/run.py

* fix pylint

* skip on run-cache
  • Loading branch information
skshetry authored Jan 13, 2021
1 parent c840cec commit 257f584
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 110 deletions.
24 changes: 18 additions & 6 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import logging
import os
from typing import TYPE_CHECKING, Any, Union

from voluptuous import MultipleInvalid

Expand All @@ -14,6 +15,7 @@
StageFileFormatError,
StageFileIsNotDvcFileError,
)
from dvc.types import AnyPath
from dvc.utils import relpath
from dvc.utils.collections import apply_diff
from dvc.utils.serialize import (
Expand All @@ -24,6 +26,9 @@
parse_yaml_for_update,
)

if TYPE_CHECKING:
from dvc.repo import Repo

logger = logging.getLogger(__name__)

DVC_FILE = "Dvcfile"
Expand Down Expand Up @@ -381,12 +386,19 @@ def merge(self, ancestor, other):


class Dvcfile:
def __new__(cls, repo, path, **kwargs):
def __new__(cls, repo: "Repo", path: AnyPath, **kwargs: Any):
assert path
assert repo

_, ext = os.path.splitext(path)
if ext in [".yaml", ".yml"]:
return PipelineFile(repo, path, **kwargs)
# fallback to single stage file for better error messages
return SingleStageFile(repo, path, **kwargs)
return make_dvcfile(repo, path, **kwargs)


DVCFile = Union["PipelineFile", "SingleStageFile"]


def make_dvcfile(repo: "Repo", path: AnyPath, **kwargs: Any) -> DVCFile:
_, ext = os.path.splitext(str(path))
if ext in [".yaml", ".yml"]:
return PipelineFile(repo, path, **kwargs)
# fallback to single stage file for better error messages
return SingleStageFile(repo, path, **kwargs)
4 changes: 2 additions & 2 deletions dvc/repo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Iterable, Optional

from dvc.repo import locked
from dvc.utils.cli_parse import loads_params_from_cli
from dvc.utils.cli_parse import loads_params

logger = logging.getLogger(__name__)

Expand All @@ -28,7 +28,7 @@ def run(
return repo.experiments.reproduce_queued(jobs=jobs)

if params:
params = loads_params_from_cli(params)
params = loads_params(params)
return repo.experiments.reproduce_one(
targets=targets, params=params, tmp_dir=tmp_dir, **kwargs
)
104 changes: 16 additions & 88 deletions dvc/repo/run.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,29 @@
import os
from contextlib import suppress
from typing import TYPE_CHECKING

from funcy import concat, first, without

from dvc.utils.cli_parse import parse_params_from_cli
from dvc.utils.collections import chunk_dict
from dvc.exceptions import InvalidArgumentError

from . import locked
from .scm_context import scm_context


def is_valid_name(name: str):
from ..stage import INVALID_STAGENAME_CHARS

return not INVALID_STAGENAME_CHARS & set(name)


def _get_file_path(kwargs):
from dvc.dvcfile import DVC_FILE, DVC_FILE_SUFFIX

out = first(
concat(
kwargs.get("outs", []),
kwargs.get("outs_no_cache", []),
kwargs.get("metrics", []),
kwargs.get("metrics_no_cache", []),
kwargs.get("plots", []),
kwargs.get("plots_no_cache", []),
kwargs.get("outs_persist", []),
kwargs.get("outs_persist_no_cache", []),
kwargs.get("checkpoints", []),
without([kwargs.get("live", None)], None),
)
)

return (
os.path.basename(os.path.normpath(out)) + DVC_FILE_SUFFIX
if out
else DVC_FILE
)


def _check_stage_exists(dvcfile, stage):
from dvc.stage import PipelineStage
from dvc.stage.exceptions import (
DuplicateStageName,
StageFileAlreadyExistsError,
)

if not dvcfile.exists():
return

hint = "Use '--force' to overwrite."
if stage.__class__ != PipelineStage:
raise StageFileAlreadyExistsError(
f"'{stage.relpath}' already exists. {hint}"
)
elif stage.name and stage.name in dvcfile.stages:
raise DuplicateStageName(
f"Stage '{stage.name}' already exists in '{stage.relpath}'. {hint}"
)
if TYPE_CHECKING:
from . import Repo


@locked
@scm_context
def run(self, fname=None, no_exec=False, single_stage=False, **kwargs):
from dvc.dvcfile import PIPELINE_FILE, Dvcfile
from dvc.exceptions import InvalidArgumentError, OutputDuplicationError
from dvc.stage import PipelineStage, Stage, create_stage, restore_meta
from dvc.stage.exceptions import InvalidStageName
def run(
self: "Repo",
fname: str = None,
no_exec: bool = False,
single_stage: bool = False,
**kwargs
):
from dvc.stage.utils import check_graphs, create_stage_from_cli

if not kwargs.get("cmd"):
raise InvalidArgumentError("command is not specified")

stage_cls = PipelineStage
path = PIPELINE_FILE
stage_name = kwargs.get("name")

if stage_name and single_stage:
raise InvalidArgumentError(
"`-n|--name` is incompatible with `--single-stage`"
Expand All @@ -91,32 +38,13 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs):
if not stage_name and not single_stage:
raise InvalidArgumentError("`-n|--name` is required")

if single_stage:
kwargs.pop("name", None)
stage_cls = Stage
path = fname or _get_file_path(kwargs)
else:
if not is_valid_name(stage_name):
raise InvalidStageName

params = chunk_dict(parse_params_from_cli(kwargs.pop("params", [])))
stage = create_stage(
stage_cls, repo=self, path=path, params=params, **kwargs
stage = create_stage_from_cli(
self, single_stage=single_stage, fname=fname, **kwargs
)
restore_meta(stage)
if kwargs.get("run_cache", True) and stage.can_be_skipped:
return None

dvcfile = Dvcfile(self, stage.path)
try:
if kwargs.get("force", True):
with suppress(ValueError):
self.stages.remove(stage)
else:
_check_stage_exists(dvcfile, stage)
self.check_modified_graph([stage])
except OutputDuplicationError as exc:
raise OutputDuplicationError(exc.output, set(exc.stages) - {stage})
check_graphs(self, stage, force=kwargs.get("force", True))

if no_exec:
stage.ignore_outs()
Expand All @@ -126,5 +54,5 @@ def run(self, fname=None, no_exec=False, single_stage=False, **kwargs):
run_cache=kwargs.get("run_cache", True),
)

dvcfile.dump(stage, update_lock=not no_exec)
stage.dump(update_lock=not no_exec)
return stage
16 changes: 11 additions & 5 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import string
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional

from funcy import cached_property, project

Expand Down Expand Up @@ -34,6 +34,9 @@
stage_dump_eq,
)

if TYPE_CHECKING:
from dvc.dvcfile import DVCFile

logger = logging.getLogger(__name__)
# Disallow all punctuation characters except hyphen and underscore
INVALID_STAGENAME_CHARS = set(string.punctuation) - {"_", "-"}
Expand Down Expand Up @@ -166,7 +169,7 @@ def path(self, path: str):
self.__dict__.pop("relpath", None)

@property
def dvcfile(self):
def dvcfile(self) -> "DVCFile":
if self.path and self._dvcfile and self.path == self._dvcfile.path:
return self._dvcfile

Expand All @@ -176,13 +179,13 @@ def dvcfile(self):
"and is detached from dvcfile."
)

from dvc.dvcfile import Dvcfile
from dvc.dvcfile import make_dvcfile

self._dvcfile = Dvcfile(self.repo, self.path)
self._dvcfile = make_dvcfile(self.repo, self.path)
return self._dvcfile

@dvcfile.setter
def dvcfile(self, dvcfile):
def dvcfile(self, dvcfile: "DVCFile") -> None:
self._dvcfile = dvcfile

def __repr__(self):
Expand Down Expand Up @@ -666,6 +669,9 @@ def merge(self, ancestor, other):

self.outs[0].merge(ancestor_out, other.outs[0])

def dump(self, update_lock: bool = False):
self.dvcfile.dump(self, update_lock=update_lock)


class PipelineStage(Stage):
def __init__(self, *args, name=None, **kwargs):
Expand Down
Loading

0 comments on commit 257f584

Please sign in to comment.