Skip to content

Commit

Permalink
dvc: support external repos in api functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Suor committed Jun 25, 2019
1 parent dde50a5 commit 4d248b4
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 18 deletions.
77 changes: 68 additions & 9 deletions dvc/api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,74 @@
from contextlib import contextmanager
import os
import tempfile

try:
from contextlib import _GeneratorContextManager as GCM
except ImportError:
from contextlib import GeneratorContextManager as GCM

from dvc.utils import remove
from dvc.utils.compat import urlparse
from dvc.repo import Repo
from dvc.external_repo import ExternalRepo


def get_url(path, repo_dir=None, remote=None):
    """Return the URL of `path` in the default or the specified remote.

    The repo is opened with ``Repo(repo_dir)``; the output tracked at
    `path` is looked up and its checksum is mapped to a location on the
    remote selected by `remote`.
    """
    dvc_repo = Repo(repo_dir)
    # Tuple unpacking fails unless exactly one output matches the path.
    (output,) = dvc_repo.find_outs_by_path(path)
    storage = dvc_repo.cloud.get_remote(remote)
    location = storage.checksum_to_path_info(output.checksum)
    return str(location)
def get_url(path, repo=None, remote=None):
    """Return the URL of the resource located at `path` inside `repo`.

    `repo` is passed to ``_make_repo``; `path` is joined onto the root
    directory of the repo it yields. The checksum of the single output
    found there is mapped to a location on the remote selected by
    `remote`, and that location is returned as a string.
    """
    with _make_repo(repo) as dvc_repo:
        target = os.path.join(dvc_repo.root_dir, path)
        # Tuple unpacking fails unless exactly one output matches.
        (output,) = dvc_repo.find_outs_by_path(target)
        storage = dvc_repo.cloud.get_remote(remote)
        location = storage.checksum_to_path_info(output.checksum)
        return str(location)


def open(path, repo=None, remote=None, mode="r", encoding=None):
    """Open a specified resource as a file descriptor.

    The result is an ``_OpenContextManager`` wrapping the ``_open``
    generator, so it must be used in a ``with`` statement; the generator
    is created immediately but only runs once the context is entered.

    Args:
        path: path of the resource, relative to the repo root.
        repo: forwarded to ``_open`` (and from there to ``_make_repo``).
        remote: name of the remote to read from, or None.
        mode: file mode forwarded to the repo's ``open``.
        encoding: text encoding forwarded to the repo's ``open``.
    """
    # Only the post-commit implementation is kept: the superseded lines
    # referenced `repo_dir`, which is not a parameter of this signature,
    # and returned before the context-manager code could run.
    return _OpenContextManager(
        _open,
        (path,),
        {
            "repo": repo,
            "remote": remote,
            "mode": mode,
            "encoding": encoding,
        },
    )


class _OpenContextManager(GCM):
def __init__(self, func, args, kwds):
self.gen = func(*args, **kwds)
self.func, self.args, self.kwds = func, args, kwds

def __getattr__(self, name):
raise AttributeError(
"dvc.api.open() should be used in a with statement"
)


def _open(path, repo=None, remote=None, mode="r", encoding=None):
    """Generator behind `open`: yield one file descriptor for `path`.

    The repo is obtained from ``_make_repo(repo)``. Both the repo
    context and the opened descriptor stay active while the consumer
    holds the yielded value, and are exited when the generator resumes.
    """
    with _make_repo(repo) as dvc_repo:
        target = os.path.join(dvc_repo.root_dir, path)
        open_kwargs = {"remote": remote, "mode": mode, "encoding": encoding}
        with dvc_repo.open(target, **open_kwargs) as stream:
            yield stream


def read(path, repo=None, remote=None, mode="r", encoding=None):
    """Return the full contents of a specified resource.

    Opens the resource through this module's `open` and returns
    whatever the descriptor's ``read()`` produces.
    """
    handle = open(path, repo, remote=remote, mode=mode, encoding=encoding)
    with handle as stream:
        contents = stream.read()
    return contents


@contextmanager
def _make_repo(repo_url):
    """Yield a repo object for `repo_url`.

    A falsy `repo_url`, or one that parses with an empty URL scheme, is
    passed straight to ``Repo``. Anything else is installed as an
    ``ExternalRepo`` into a fresh temporary directory, which is removed
    when the context exits.
    """
    is_local = not repo_url or urlparse(repo_url).scheme == ""
    if is_local:
        yield Repo(repo_url)
        return

    # Passed positionally, "dvc-repo" is the directory-name *suffix*.
    workdir = tempfile.mkdtemp("dvc-repo")
    try:
        external = ExternalRepo(workdir, url=repo_url)
        external.install()
        yield external.repo
    finally:
        # Runs on normal exit and on errors raised by the consumer.
        remove(workdir)
46 changes: 37 additions & 9 deletions tests/func/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dvc import api
from dvc.main import main
from dvc.path_info import URLInfo
from dvc.remote.config import RemoteConfig
from .test_data_cloud import (
_should_test_aws,
get_aws_url,
Expand Down Expand Up @@ -55,33 +56,40 @@ class HDFS:


@pytest.fixture
def remote_url(request):
    """Provide the URL of the parametrized remote, or skip the test.

    ``request.param`` is the remote description injected through
    indirect parametrization; the test is skipped when its
    ``should_test()`` returns a falsy value.
    """
    # Only the post-commit fixture is kept: the superseded `remote`
    # header and its `return request.param` made this return unreachable.
    if not request.param.should_test():
        raise pytest.skip()
    return request.param.get_url()


def pytest_generate_tests(metafunc):
    """Parametrize the remote fixtures over `remote_params`.

    Each of the `remote` and `remote_url` fixtures, when requested by
    the test being collected, is parametrized indirectly so the fixture
    receives every entry as ``request.param``.
    """
    for fixture_name in ("remote", "remote_url"):
        if fixture_name in metafunc.fixturenames:
            metafunc.parametrize(fixture_name, remote_params, indirect=True)


def run_dvc(*argv):
    """Run the dvc entry point with `argv` and assert it returns 0."""
    exit_code = main(argv)
    assert exit_code == 0


def test_get_url(repo_dir, dvc_repo, remote_url):
    """api.get_url() maps a tracked file to its location on the remote."""
    # Only the post-commit test is kept: the superseded stub that took
    # `remote` was shadowed by this definition and never ran.
    run_dvc("remote", "add", "-d", "upstream", remote_url)
    dvc_repo.add(repo_dir.FOO)

    expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8"
    assert api.get_url(repo_dir.FOO) == expected_url


def test_open(repo_dir, dvc_repo, remote):
run_dvc("remote", "add", "-d", "upstream", remote.get_url())
def test_get_url_external(repo_dir, dvc_repo, erepo, remote_url):
    """api.get_url() resolves a file through a repo given by URL."""
    _set_remote_url_and_commit(erepo.dvc, remote_url)

    # A file:// URL makes the API clone the repo into a temporary dir.
    external_root = erepo.dvc.root_dir
    repo_url = "file://" + external_root
    expected_url = URLInfo(remote_url) / "ac/bd18db4cc2f85cedef654fccc4a4d8"

    actual_url = api.get_url(repo_dir.FOO, repo=repo_url)
    assert actual_url == expected_url


def test_open(repo_dir, dvc_repo, remote_url):
run_dvc("remote", "add", "-d", "upstream", remote_url)
dvc_repo.add(repo_dir.FOO)
run_dvc("push")

Expand All @@ -90,3 +98,23 @@ def test_open(repo_dir, dvc_repo, remote):

with api.open(repo_dir.FOO) as fd:
assert fd.read() == repo_dir.FOO_CONTENTS


def test_open_external(repo_dir, dvc_repo, erepo, remote_url):
    """api.open() reads a file through a repo given by URL."""
    _set_remote_url_and_commit(erepo.dvc, remote_url)
    erepo.dvc.push()

    # Drop the local cache so the contents must come from the remote.
    cache_dir = erepo.dvc.cache.local.cache_dir
    shutil.rmtree(cache_dir)

    # A file:// URL makes the API clone the repo into a temporary dir.
    external_root = erepo.dvc.root_dir
    repo_url = "file://" + external_root
    with api.open(repo_dir.FOO, repo=repo_url) as stream:
        contents = stream.read()
        assert contents == repo_dir.FOO_CONTENTS


def _set_remote_url_and_commit(repo, remote_url):
    """Set the "upstream" remote's url to `remote_url` and commit it.

    The change is made through ``RemoteConfig``; the repo's config file
    is then added to SCM and committed.
    """
    remote_config = RemoteConfig(repo.config)
    remote_config.modify("upstream", "url", remote_url)

    changed_files = [repo.config.config_file]
    repo.scm.add(changed_files)
    repo.scm.commit("modify remote")

0 comments on commit 4d248b4

Please sign in to comment.