dvcfile: support remote per output (iterative#6486)
Related to iterative#2095
efiop authored Aug 27, 2021
1 parent 95ba68d commit 4382b80
Showing 8 changed files with 121 additions and 20 deletions.
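
In short, each output can now name the remote its cache should be pushed to and pulled from via a new "remote" field, instead of always going through the repository default. The snippet below is an illustrative sketch only, not code from this commit: it mirrors the functional test added further down in tests/func/test_data_cloud.py, and the repo setup, file name, and remote name "for_foo" are assumptions.

# Illustrative sketch: pin one tracked file to its own remote.
# Assumes a DVC repo in the current directory with a default remote and a
# remote named "for_foo" already configured, and a file "foo" on disk.
from dvc.repo import Repo
from dvc.utils.serialize import modify_yaml

repo = Repo(".")
repo.add("foo")  # writes foo.dvc tracking "foo"

# Set the new per-output "remote" flag in the generated .dvc file.
with modify_yaml("foo.dvc") as d:
    d["outs"][0]["remote"] = "for_foo"

# "foo" is now pushed to and pulled from "for_foo"; other outputs keep
# using the default remote.
repo.push()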
43 changes: 30 additions & 13 deletions dvc/output.py
@@ -81,6 +81,7 @@ def loadd_from(stage, d_list):
desc = d.pop(Output.PARAM_DESC, False)
isexec = d.pop(Output.PARAM_ISEXEC, False)
live = d.pop(Output.PARAM_LIVE, False)
remote = d.pop(Output.PARAM_REMOTE, None)
ret.append(
_get(
stage,
@@ -94,6 +95,7 @@ def loadd_from(stage, d_list):
desc=desc,
isexec=isexec,
live=live,
remote=remote,
)
)
return ret
@@ -109,6 +111,7 @@ def loads_from(
checkpoint=False,
isexec=False,
live=False,
remote=None,
):
return [
_get(
@@ -122,6 +125,7 @@ def loads_from(
checkpoint=checkpoint,
isexec=isexec,
live=live,
remote=remote,
)
for s in s_list
]
@@ -162,7 +166,6 @@ def load_from_pipeline(stage, data, typ="outs"):
metric = typ == stage.PARAM_METRICS
plot = typ == stage.PARAM_PLOTS
live = typ == stage.PARAM_LIVE

if live:
# `live` is single object
data = [data]
@@ -185,6 +188,7 @@ def load_from_pipeline(stage, data, typ="outs"):
Output.PARAM_CACHE,
Output.PARAM_PERSIST,
Output.PARAM_CHECKPOINT,
Output.PARAM_REMOTE,
],
)

@@ -255,6 +259,7 @@ class Output:
PARAM_LIVE = "live"
PARAM_LIVE_SUMMARY = "summary"
PARAM_LIVE_HTML = "html"
PARAM_REMOTE = "remote"

METRIC_SCHEMA = Any(
None,
@@ -283,6 +288,7 @@ def __init__(
live=False,
desc=None,
isexec=False,
remote=None,
):
self.repo = stage.repo if stage else None

@@ -326,7 +332,7 @@ def __init__(
self.obj = None
self.isexec = False if self.IS_DEPENDENCY else isexec

self.def_remote = None
self.remote = remote

def _parse_path(self, fs, path_info):
if fs.scheme != "local":
@@ -843,7 +849,8 @@ def get_dir_cache(self, **kwargs):
try:
objects.check(self.odb, obj)
except FileNotFoundError:
self.repo.cloud.pull([obj.hash_info], **kwargs)
remote = self.repo.cloud.get_remote_odb(self.remote)
self.repo.cloud.pull([obj.hash_info], odb=remote, **kwargs)

if self.obj:
return self.obj
@@ -855,9 +862,9 @@ def get_dir_cache(self, **kwargs):

return self.obj

def collect_used_dir_cache(
def _collect_used_dir_cache(
self, remote=None, force=False, jobs=None, filter_info=None
) -> Dict[Optional["ObjectDB"], Set["HashInfo"]]:
) -> Optional["Tree"]:
"""Fetch dir cache and return used object IDs for this out."""

try:
@@ -878,13 +885,13 @@ def collect_used_dir_cache(
"unable to fully collect used cache"
" without cache for directory '{}'".format(self)
)
return {}
return None

obj = self.get_obj()
if filter_info and filter_info != self.path_info:
prefix = filter_info.relative_to(self.path_info).parts
obj = obj.filter(prefix)
return {None: set(self._named_obj_ids(obj))}
return obj

def get_used_objs(
self, **kwargs
@@ -917,22 +924,31 @@ def get_used_objs(
return {}

if self.is_dir_checksum:
return self.collect_used_dir_cache(**kwargs)
obj = self._collect_used_dir_cache(**kwargs)
else:
obj = self.get_obj(filter_info=kwargs.get("filter_info"))
if not obj:
obj = self.odb.get(self.hash_info)

obj = self.get_obj(filter_info=kwargs.get("filter_info"))
if not obj:
obj = self.odb.get(self.hash_info)
return {}

if self.remote:
remote = self.repo.cloud.get_remote_odb(name=self.remote)
else:
remote = None

return {None: set(self._named_obj_ids(obj))}
return {remote: self._named_obj_ids(obj)}

def _named_obj_ids(self, obj):
name = str(self)
obj.hash_info.obj_name = name
yield obj.hash_info
oids = {obj.hash_info}
if isinstance(obj, Tree):
for key, entry_obj in obj:
entry_obj.hash_info.obj_name = self.fs.sep.join([name, *key])
yield entry_obj.hash_info
oids.add(entry_obj.hash_info)
return oids

def get_used_external(
self, **kwargs
@@ -1033,4 +1049,5 @@ def is_plot(self) -> bool:
Output.PARAM_CACHE: bool,
Output.PARAM_METRIC: Output.METRIC_SCHEMA,
Output.PARAM_DESC: str,
Output.PARAM_REMOTE: str,
}
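
Not shown in this diff: whatever consumes get_used_objs() (presumably the push/fetch plumbing elsewhere in the codebase) now receives a mapping from an optional ObjectDB to the set of object IDs that belong to it, with None standing for the default remote. Below is a hedged sketch of how such a caller might aggregate results across outputs; the helper name and loop are assumptions, not code from this commit.

# Hypothetical aggregation helper (not from this commit): group used object
# IDs by the remote ODB each output is pinned to; a key of None means the
# default remote should be used for those objects.
from collections import defaultdict

def group_used_objs(outs, **kwargs):
    used = defaultdict(set)
    for out in outs:
        for odb, obj_ids in out.get_used_objs(**kwargs).items():
            used[odb].update(obj_ids)
    return used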
1 change: 1 addition & 0 deletions dvc/schema.py
@@ -49,6 +49,7 @@
Output.PARAM_PERSIST: bool,
Output.PARAM_CHECKPOINT: bool,
Output.PARAM_DESC: str,
Output.PARAM_REMOTE: str,
}
}

3 changes: 3 additions & 0 deletions dvc/stage/serialize.py
@@ -28,6 +28,7 @@
PARAM_PERSIST = Output.PARAM_PERSIST
PARAM_CHECKPOINT = Output.PARAM_CHECKPOINT
PARAM_DESC = Output.PARAM_DESC
PARAM_REMOTE = Output.PARAM_REMOTE

DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE

@@ -52,6 +53,8 @@ def _get_flags(out):
yield from out.plot.items()
if out.live and isinstance(out.live, dict):
yield from out.live.items()
if out.remote:
yield PARAM_REMOTE, out.remote


def _serialize_out(out):
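
With the flag emitted above, an output's remote setting round-trips through dvc.yaml. Purely for illustration (the paths and remote name are assumptions; the shape matches tests/unit/output/test_load.py::test_load_remote below), a stage's outs entry then looks like this when written out as Python data:

# Illustrative outs entry as handled by load_from_pipeline/_get_flags:
# a bare string uses the default remote, a mapping carries per-output flags.
outs = [
    "foo",                            # no flags: default remote
    {"bar": {"remote": "myremote"}},  # pinned to the "myremote" remote
]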
44 changes: 44 additions & 0 deletions tests/func/test_data_cloud.py
@@ -609,3 +609,47 @@ def test_pull_no_00_prefix(tmp_dir, dvc, remote, monkeypatch):
stats = dvc.pull()
assert stats["fetched"] == 2
assert set(stats["added"]) == {"foo", "bar"}


def test_output_remote(tmp_dir, dvc, make_remote):
from dvc.utils.serialize import modify_yaml

make_remote("default", default=True)
make_remote("for_foo", default=False)
make_remote("for_data", default=False)

tmp_dir.dvc_gen("foo", "foo")
tmp_dir.dvc_gen("bar", "bar")
tmp_dir.dvc_gen("data", {"one": "one", "two": "two"})

with modify_yaml("foo.dvc") as d:
d["outs"][0]["remote"] = "for_foo"

with modify_yaml("data.dvc") as d:
d["outs"][0]["remote"] = "for_data"

dvc.push()

default = dvc.cloud.get_remote_odb("default")
for_foo = dvc.cloud.get_remote_odb("for_foo")
for_data = dvc.cloud.get_remote_odb("for_data")

assert set(default.all()) == {"37b51d194a7513e45b56f6524f2d51f2"}
assert set(for_foo.all()) == {"acbd18db4cc2f85cedef654fccc4a4d8"}
assert set(for_data.all()) == {
"f97c5d29941bfb1b2fdab0874906ab82",
"6b18131dc289fd37006705affe961ef8.dir",
"b8a9f715dbb64fd5c56e7783c6820a61",
}

clean(["foo", "bar", "data"], dvc)

dvc.pull()

assert set(dvc.odb.local.all()) == {
"37b51d194a7513e45b56f6524f2d51f2",
"acbd18db4cc2f85cedef654fccc4a4d8",
"f97c5d29941bfb1b2fdab0874906ab82",
"6b18131dc289fd37006705affe961ef8.dir",
"b8a9f715dbb64fd5c56e7783c6820a61",
}
2 changes: 1 addition & 1 deletion tests/func/test_remote.py
@@ -200,7 +200,7 @@ def unreliable_upload(self, fobj, to_info, **kwargs):
dvc.push()
remove(dvc.odb.local.cache_dir)

baz.collect_used_dir_cache()
baz._collect_used_dir_cache()
with patch.object(LocalFileSystem, "upload", side_effect=Exception):
with pytest.raises(DownloadError) as download_error_info:
dvc.pull()
20 changes: 19 additions & 1 deletion tests/remotes/__init__.py
@@ -16,7 +16,7 @@
)
from .hdfs import HDFS, hadoop, hdfs, hdfs_server, real_hdfs # noqa: F401
from .http import HTTP, http, http_server # noqa: F401
from .local import Local, local_cloud, local_remote # noqa: F401
from .local import Local, local_cloud, local_remote, make_local # noqa: F401
from .oss import ( # noqa: F401
OSS,
TEST_OSS_REPO_BUCKET,
@@ -100,6 +100,24 @@ def docker_services(
return Services(executor)


@pytest.fixture
def make_cloud(request):
def _make_cloud(typ):
return request.getfixturevalue(f"make_{typ}")()

return _make_cloud


@pytest.fixture
def make_remote(tmp_dir, dvc, make_cloud):
def _make_remote(name, typ="local", **kwargs):
cloud = make_cloud(typ)
tmp_dir.add_remote(name=name, config=cloud.config, **kwargs)
return cloud

return _make_remote


@pytest.fixture
def remote(tmp_dir, dvc, request):
cloud = request.param
18 changes: 13 additions & 5 deletions tests/remotes/local.py
@@ -17,11 +17,19 @@ def get_url():


@pytest.fixture
def local_cloud(make_tmp_dir):
ret = make_tmp_dir("local-cloud")
ret.url = str(ret)
ret.config = {"url": ret.url}
return ret
def make_local(make_tmp_dir):
def _make_local():
ret = make_tmp_dir("local-cloud")
ret.url = str(ret)
ret.config = {"url": ret.url}
return ret

return _make_local


@pytest.fixture
def local_cloud(make_local):
return make_local()


@pytest.fixture
10 changes: 10 additions & 0 deletions tests/unit/output/test_load.py
@@ -79,6 +79,16 @@ def test_load_remote_files_from_pipeline(dvc):
assert not out.hash_info


def test_load_remote(dvc):
stage = Stage(dvc)
(foo, bar) = output.load_from_pipeline(
stage,
["foo", {"bar": {"remote": "myremote"}}],
)
assert foo.remote is None
assert bar.remote == "myremote"


@pytest.mark.parametrize("typ", [None, "", "illegal"])
def test_load_from_pipeline_error_on_typ(dvc, typ):
with pytest.raises(ValueError):
