Skip to content

Commit

Permalink
Merge pull request iterative#1794 from efiop/dvctags
Browse files Browse the repository at this point in the history
[RFC] introduce `dvc tag`
  • Loading branch information
efiop authored Mar 30, 2019
2 parents c645fdd + 4c5cb7a commit 374b75f
Show file tree
Hide file tree
Showing 22 changed files with 496 additions and 90 deletions.
2 changes: 2 additions & 0 deletions dvc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dvc.command.pipeline as pipeline
import dvc.command.daemon as daemon
import dvc.command.commit as commit
import dvc.command.tag as tag
from dvc.exceptions import DvcParserError
from dvc import VERSION

Expand Down Expand Up @@ -60,6 +61,7 @@
pipeline,
daemon,
commit,
tag,
]


Expand Down
151 changes: 151 additions & 0 deletions dvc/command/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import yaml

import dvc.logger as logger
from dvc.exceptions import DvcException
from dvc.command.base import CmdBase, fix_subparsers, append_doc_link


class CmdTagAdd(CmdBase):
    """Command handler that adds a tag to each requested target.

    Delegates to ``self.repo.tag.add`` once per entry in
    ``self.args.targets``.
    """

    def run(self):
        """Add ``self.args.tag`` to every target, stopping at the first error.

        Returns:
            int: 0 if every call succeeded, 1 as soon as one of them
            raises ``DvcException`` (remaining targets are not processed).
        """
        options = self.args
        exit_code = 0
        for dvc_file in options.targets:
            try:
                self.repo.tag.add(
                    options.tag,
                    target=dvc_file,
                    with_deps=options.with_deps,
                    recursive=options.recursive,
                )
            except DvcException:
                logger.error("failed to add tag")
                exit_code = 1
                break
        return exit_code


class CmdTagRemove(CmdBase):
    """Command handler that removes a tag from each requested target.

    Delegates to ``self.repo.tag.remove`` once per entry in
    ``self.args.targets``.
    """

    def run(self):
        """Remove ``self.args.tag`` from every target, stopping on error.

        Returns:
            int: 0 if every call succeeded, 1 as soon as one of them
            raises ``DvcException`` (remaining targets are not processed).
        """
        options = self.args
        exit_code = 0
        for dvc_file in options.targets:
            try:
                self.repo.tag.remove(
                    options.tag,
                    target=dvc_file,
                    with_deps=options.with_deps,
                    recursive=options.recursive,
                )
            except DvcException:
                logger.error("failed to remove tag")
                exit_code = 1
                break
        return exit_code


class CmdTagList(CmdBase):
    """Command handler that prints the tags of each requested target.

    Delegates to ``self.repo.tag.list`` once per entry in
    ``self.args.targets`` and logs any non-empty result as block-style YAML.
    """

    def run(self):
        """List tags for every target, stopping at the first error.

        Returns:
            int: 0 if every target was processed, 1 as soon as a
            ``DvcException`` is raised (remaining targets are skipped).
        """
        for target in self.args.targets:
            try:
                tags = self.repo.tag.list(
                    target,
                    with_deps=self.args.with_deps,
                    recursive=self.args.recursive,
                )
                # Empty results are skipped so nothing is logged for them.
                if tags:
                    logger.info(yaml.dump(tags, default_flow_style=False))
            except DvcException:
                # Message fixed from "failed list tags" to match the wording
                # of the sibling add/remove commands.
                logger.error("failed to list tags")
                return 1
        return 0


def add_parser(subparsers, parent_parser):
    """Attach the ``tag`` command group to the given sub-parser collection.

    Creates a ``tag`` parser with three sub-commands: ``add``, ``remove``
    and ``list``. Each accepts zero or more ``targets`` (defaulting to
    ``[None]``) plus the ``-d/--with-deps`` and ``-R/--recursive`` switches.
    ``add`` and ``remove`` also take a leading positional ``tag``. The
    handler class for each sub-command is stored as the ``func`` default.

    Args:
        subparsers: sub-parser collection on which the ``tag`` parser is
            created.
        parent_parser: parser passed as ``parents=[...]`` to the ``tag``
            parser and to each of its sub-command parsers.
    """
    group_help = "A set of commands to manage DVC tags."
    tag_parser = subparsers.add_parser(
        "tag",
        parents=[parent_parser],
        description=append_doc_link(group_help, "tag"),
        add_help=False,
    )

    tag_subparsers = tag_parser.add_subparsers(
        dest="cmd",
        help="Use DVC tag CMD --help to display command-specific help.",
    )
    fix_subparsers(tag_subparsers)

    # One row per sub-command:
    # (name, doc-link slug, handler class, takes a positional tag?,
    #  summary, --with-deps help, --recursive help)
    commands = (
        (
            "add",
            "tag-add",
            CmdTagAdd,
            True,
            "Add DVC tag.",
            "Add tag for all dependencies of the specified DVC file.",
            "Add tag for subdirectories of the specified directory.",
        ),
        (
            "remove",
            "tag-remove",
            CmdTagRemove,
            True,
            "Remove DVC tag.",
            "Remove tag for all dependencies of the specified DVC file.",
            "Remove tag for subdirectories of the specified directory.",
        ),
        (
            "list",
            "tag-list",
            CmdTagList,
            False,
            "List DVC tags.",
            "List tags for all dependencies of the specified DVC file.",
            "List tags for subdirectories of the specified directory.",
        ),
    )

    for spec in commands:
        name, slug, handler, takes_tag, summary, deps_help, rec_help = spec

        cmd_parser = tag_subparsers.add_parser(
            name,
            parents=[parent_parser],
            description=append_doc_link(summary, slug),
            help=summary,
        )

        # Positional order matters: ``tag`` (when present) precedes
        # ``targets``.
        if takes_tag:
            cmd_parser.add_argument("tag", help="Dvc tag.")
        cmd_parser.add_argument(
            "targets", nargs="*", default=[None], help="Dvc files."
        )
        cmd_parser.add_argument(
            "-d",
            "--with-deps",
            action="store_true",
            default=False,
            help=deps_help,
        )
        cmd_parser.add_argument(
            "-R",
            "--recursive",
            action="store_true",
            default=False,
            help=rec_help,
        )
        cmd_parser.set_defaults(func=handler)
7 changes: 7 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,10 @@ def __init__(self, path):
super(TargetNotDirectoryError, self).__init__(
"Target: {} is not a directory".format(path)
)


class CheckoutErrorSuggestGit(DvcException):
    """Error whose message suggests running ``git checkout <target>``.

    Args:
        target: value interpolated into the suggested ``git checkout``
            command.
        cause: forwarded to ``DvcException`` as the ``cause`` keyword.
    """

    def __init__(self, target, cause):
        hint = "Did you mean 'git checkout {}'?".format(target)
        super(CheckoutErrorSuggestGit, self).__init__(hint, cause=cause)
57 changes: 41 additions & 16 deletions dvc/output/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,31 @@
"local": OutputLOCAL,
}

SCHEMA = {
OutputBase.PARAM_PATH: str,
# NOTE: currently there are only 3 possible checksum names:
#
# 1) md5 (LOCAL, SSH, GS);
# 2) etag (S3);
# 3) checksum (HDFS);
#
# so when a few types of outputs share the same name, we only need to
# specify it once.
# NOTE: currently there are only 3 possible checksum names:
#
# 1) md5 (LOCAL, SSH, GS);
# 2) etag (S3);
# 3) checksum (HDFS);
#
# so when a few types of outputs share the same name, we only need to
# specify it once.
CHECKSUM_SCHEMA = {
schema.Optional(RemoteLOCAL.PARAM_CHECKSUM): schema.Or(str, None),
schema.Optional(RemoteS3.PARAM_CHECKSUM): schema.Or(str, None),
schema.Optional(RemoteHDFS.PARAM_CHECKSUM): schema.Or(str, None),
schema.Optional(OutputBase.PARAM_CACHE): bool,
schema.Optional(OutputBase.PARAM_METRIC): OutputBase.METRIC_SCHEMA,
schema.Optional(OutputBase.PARAM_PERSIST): bool,
}

TAGS_SCHEMA = {schema.Optional(str): CHECKSUM_SCHEMA}

def _get(stage, p, info, cache, metric, persist):
SCHEMA = CHECKSUM_SCHEMA.copy()
SCHEMA[OutputBase.PARAM_PATH] = str
SCHEMA[schema.Optional(OutputBase.PARAM_CACHE)] = bool
SCHEMA[schema.Optional(OutputBase.PARAM_METRIC)] = OutputBase.METRIC_SCHEMA
SCHEMA[schema.Optional(OutputBase.PARAM_TAGS)] = TAGS_SCHEMA
SCHEMA[schema.Optional(OutputBase.PARAM_PERSIST)] = bool


def _get(stage, p, info, cache, metric, persist=False, tags=None):
parsed = urlparse(p)
if parsed.scheme == "remote":
name = Config.SECTION_REMOTE_FMT.format(parsed.netloc)
Expand All @@ -66,11 +71,21 @@ def _get(stage, p, info, cache, metric, persist):
remote=remote,
metric=metric,
persist=persist,
tags=tags,
)

for o in OUTS:
if o.supported(p):
return o(stage, p, info, cache=cache, remote=None, metric=metric)
return o(
stage,
p,
info,
cache=cache,
remote=None,
metric=metric,
persist=persist,
tags=tags,
)
return OutputLOCAL(
stage,
p,
Expand All @@ -79,6 +94,7 @@ def _get(stage, p, info, cache, metric, persist):
remote=None,
metric=metric,
persist=persist,
tags=tags,
)


Expand All @@ -89,8 +105,17 @@ def loadd_from(stage, d_list):
cache = d.pop(OutputBase.PARAM_CACHE, True)
metric = d.pop(OutputBase.PARAM_METRIC, False)
persist = d.pop(OutputBase.PARAM_PERSIST, False)
tags = d.pop(OutputBase.PARAM_TAGS, None)
ret.append(
_get(stage, p, info=d, cache=cache, metric=metric, persist=persist)
_get(
stage,
p,
info=d,
cache=cache,
metric=metric,
persist=persist,
tags=tags,
)
)
return ret

Expand Down
19 changes: 17 additions & 2 deletions dvc/output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class OutputBase(object):
},
)

PARAM_TAGS = "tags"

DoesNotExistError = OutputDoesNotExistError
IsNotFileOrDirError = OutputIsNotFileOrDirError

Expand All @@ -59,6 +61,7 @@ def __init__(
cache=True,
metric=False,
persist=False,
tags=None,
):
self.stage = stage
self.repo = stage.repo
Expand All @@ -68,6 +71,7 @@ def __init__(
self.use_cache = False if self.IS_DEPENDENCY else cache
self.metric = False if self.IS_DEPENDENCY else metric
self.persist = persist
self.tags = None if self.IS_DEPENDENCY else (tags or {})

if (
self.use_cache
Expand Down Expand Up @@ -197,6 +201,9 @@ def dumpd(self):
ret[self.PARAM_METRIC] = self.metric
ret[self.PARAM_PERSIST] = self.persist

if self.tags:
ret[self.PARAM_TAGS] = self.tags

return ret

def verify_metric(self):
Expand All @@ -207,12 +214,20 @@ def verify_metric(self):
def download(self, to_info, resume=False):
self.remote.download([self.path_info], [to_info], resume=resume)

def checkout(self, force=False, progress_callback=None):
def checkout(self, force=False, progress_callback=None, tag=None):
if not self.use_cache:
return

if tag:
info = self.tags[tag]
else:
info = self.info

getattr(self.repo.cache, self.scheme).checkout(
output=self, force=force, progress_callback=progress_callback
self.path_info,
info,
force=force,
progress_callback=progress_callback,
)

def remove(self, ignore_remove=False):
Expand Down
2 changes: 2 additions & 0 deletions dvc/output/hdfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
cache=True,
metric=False,
persist=False,
tags=None,
):
super(OutputHDFS, self).__init__(
stage,
Expand All @@ -28,6 +29,7 @@ def __init__(
cache=cache,
metric=metric,
persist=persist,
tags=tags,
)
if remote:
path = posixpath.join(remote.url, urlparse(path).path.lstrip("/"))
Expand Down
9 changes: 3 additions & 6 deletions dvc/output/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
cache=True,
metric=False,
persist=False,
tags=None,
):
super(OutputLOCAL, self).__init__(
stage,
Expand All @@ -31,6 +32,7 @@ def __init__(
cache=cache,
metric=metric,
persist=persist,
tags=tags,
)
if remote:
p = os.path.join(
Expand Down Expand Up @@ -155,12 +157,7 @@ def save(self):

@property
def dir_cache(self):
if self.checksum not in self._dir_cache.keys():
self._dir_cache[
self.checksum
] = self.repo.cache.local.load_dir_cache(self.checksum)

return self._dir_cache[self.checksum]
return self.repo.cache.local.load_dir_cache(self.checksum)

def get_files_number(self):
if self.cache is None:
Expand Down
2 changes: 2 additions & 0 deletions dvc/output/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
cache=True,
metric=False,
persist=False,
tags=None,
):
super(OutputS3, self).__init__(
stage,
Expand All @@ -28,6 +29,7 @@ def __init__(
cache=cache,
metric=metric,
persist=persist,
tags=tags,
)
bucket = remote.bucket if remote else urlparse(path).netloc
path = urlparse(path).path.lstrip("/")
Expand Down
Loading

0 comments on commit 374b75f

Please sign in to comment.