Skip to content

Commit

Permalink
erepo: make repo configuration explicit (iterative#5174)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop authored Dec 29, 2020
1 parent db555ec commit 6fcafc9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 67 deletions.
9 changes: 6 additions & 3 deletions dvc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class Config(dict):
CONFIG_LOCAL = "config.local"

def __init__(
self, dvc_dir=None, validate=True, tree=None,
self, dvc_dir=None, validate=True, tree=None, config=None,
): # pylint: disable=super-init-not-called
from dvc.tree.local import LocalTree

Expand All @@ -285,7 +285,7 @@ def __init__(
self.wtree = LocalTree(None, {"url": self.dvc_dir})
self.tree = tree or self.wtree

self.load(validate=validate)
self.load(validate=validate, config=config)

@classmethod
def get_dir(cls, level):
Expand Down Expand Up @@ -325,14 +325,17 @@ def init(dvc_dir):
open(config_file, "w+").close()
return Config(dvc_dir)

def load(self, validate=True):
def load(self, validate=True, config=None):
"""Loads config from all the config files.
Raises:
ConfigError: thrown if config has an invalid format.
"""
conf = self.load_config_to_level()

if config is not None:
merge(conf, config)

if validate:
conf = self.validate(conf)

Expand Down
110 changes: 49 additions & 61 deletions dvc/external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from funcy import cached_property, reraise, retry, wrap_with

from dvc.cache import Cache
from dvc.config import NoRemoteError, NotDvcRepoError
from dvc.exceptions import (
DvcException,
Expand Down Expand Up @@ -87,6 +86,30 @@ def clean_repos():
_remove(path)


def _get_remote_config(url):
try:
repo = Repo(url)
except NotDvcRepoError:
return {}

try:
name = repo.config["core"].get("remote")
if not name:
# Fill the empty upstream entry with a new remote pointing to the
# original repo's cache location.
name = "auto-generated-upstream"
return {
"core": {"remote": name},
"remote": {name: {"url": repo.cache.local.cache_dir}},
}

# Use original remote to make sure that we are using correct url,
# credential paths, etc if they are relative to the config location.
return {"remote": {name: repo.config["remote"][name]}}
finally:
repo.close()


class ExternalRepo(Repo):
# pylint: disable=no-member

Expand All @@ -102,19 +125,29 @@ def __init__(
uninitialized=False,
**kwargs,
):
super().__init__(
root_dir, scm=scm, rev=rev, uninitialized=uninitialized
)

self.url = url
self.for_write = for_write
self.cache_dir = cache_dir or self._get_cache_dir()
self.cache_types = cache_types

self._setup_cache(self)
self._fix_upstream(self)
self.tree_confs = kwargs

self._cache_config = {
"cache": {
"dir": cache_dir or self._get_cache_dir(),
"type": cache_types,
}
}

config = self._cache_config.copy()
if os.path.isdir(url):
config.update(_get_remote_config(url))

super().__init__(
root_dir,
scm=scm,
rev=rev,
uninitialized=uninitialized,
config=config,
)

def __str__(self):
return self.url

Expand Down Expand Up @@ -218,58 +251,13 @@ def _get_tree_for(self, repo, **kwargs):
kw["fetch"] = True
return RepoTree(repo, **kw)

@staticmethod
def _fix_local_remote(orig_repo, src_repo, remote_name):
# If a remote URL is relative to the source repo,
# it will have changed upon config load and made
# relative to this new repo. Restore the old one here.
new_remote = orig_repo.config["remote"][remote_name]
old_remote = src_repo.config["remote"][remote_name]
if new_remote["url"] != old_remote["url"]:
new_remote["url"] = old_remote["url"]

@staticmethod
def _add_upstream(orig_repo, src_repo):
# Fill the empty upstream entry with a new remote pointing to the
# original repo's cache location.
cache_dir = src_repo.cache.local.cache_dir
orig_repo.config["remote"]["auto-generated-upstream"] = {
"url": cache_dir
}
orig_repo.config["core"]["remote"] = "auto-generated-upstream"

def make_repo(self, path):
repo = Repo(path, scm=self.scm, rev=self.get_rev())

self._setup_cache(repo)
self._fix_upstream(repo)

return repo

def _setup_cache(self, repo):
repo.config["cache"]["dir"] = self.cache_dir
repo.cache = Cache(repo)
if self.cache_types:
repo.cache.local.cache_types = self.cache_types

def _fix_upstream(self, repo):
if not os.path.isdir(self.url):
return

try:
rel_path = os.path.relpath(repo.root_dir, self.root_dir)
src_repo = Repo(PathInfo(self.url) / rel_path)
except NotDvcRepoError:
return

try:
remote_name = repo.config["core"].get("remote")
if remote_name:
self._fix_local_remote(repo, src_repo, remote_name)
else:
self._add_upstream(repo, src_repo)
finally:
src_repo.close()
config = self._cache_config.copy()
if os.path.isdir(self.url):
rel = os.path.relpath(path, self.root_dir)
repo_path = os.path.join(self.url, rel)
config.update(_get_remote_config(repo_path))
return Repo(path, scm=self.scm, rev=self.get_rev(), config=config)

@wrap_with(threading.Lock())
def _get_cache_dir(self):
Expand Down
3 changes: 2 additions & 1 deletion dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(
rev=None,
subrepos=False,
uninitialized=False,
config=None,
):
from dvc.cache import Cache
from dvc.data_cloud import DataCloud
Expand All @@ -150,7 +151,7 @@ def __init__(
else:
self.tree = LocalTree(self, {"url": self.root_dir}, **tree_kwargs)

self.config = Config(self.dvc_dir, tree=self.tree)
self.config = Config(self.dvc_dir, tree=self.tree, config=config)
self._scm = scm

# used by RepoTree to determine if it should traverse subrepos
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_external_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def test_subrepo_is_constructed_properly(
subrepo = spy.return_value

assert repo.url == str(tmp_dir)
assert repo.cache_dir == str(cache_dir)
assert repo._cache_config["cache"]["dir"] == str(cache_dir)
assert repo.cache.local.cache_dir == str(cache_dir)
assert subrepo.cache.local.cache_dir == str(cache_dir)

assert repo.cache_types == ["symlink"]
assert repo._cache_config["cache"]["type"] == ["symlink"]
assert repo.cache.local.cache_types == ["symlink"]
assert subrepo.cache.local.cache_types == ["symlink"]

Expand Down

0 comments on commit 6fcafc9

Please sign in to comment.