Skip to content

Commit

Permalink
odb: treat external repo dependencies as objects (iterative#6109)
Browse files Browse the repository at this point in the history
* objects: add ExternalRepoFile object type

* RepoDependency: implement get_obj (use ExternalRepoFile)

* drop distinction between used external and used objects

* update tests

* remove unneeded fetch from ExternalRepoFile

* wrap erepo errors as an object error

* update cloud/remote functions for new used_objs behavior

* fix update() rev/locked behavior

* update import tests

* remove ExternalRepoFile

* move get_obj and get_used_objs implementations back into repo dependency

* return objects plus associated ODB (remote) in get_used_objs

* odb: add tmp_dir field and support creating Remote from an odb

* return dict mapping odb -> objects in get_used_objs

* use dummy git odb for git imports

* fetch: support pulling from specific odb

* update cloud funcs for get_used_objs usage

* update tests for used_objs behavior

* use fetch in dep.download

* catch circular imports from local fs repos

* add tests for chained and circular imports

* move tmp_dir into odb config

* update return docstring for repo.used_objs
  • Loading branch information
pmrowla authored Jun 17, 2021
1 parent 1a5fa22 commit 1d3524f
Show file tree
Hide file tree
Showing 19 changed files with 411 additions and 198 deletions.
159 changes: 108 additions & 51 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import NamedTuple, Optional
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set

from voluptuous import Required

from dvc.path_info import PathInfo

from .base import Dependency


class RepoPair(NamedTuple):
url: str
rev: Optional[str] = None
if TYPE_CHECKING:
from dvc.objects.db.base import ObjectDB
from dvc.objects.file import HashFile


class RepoDependency(Dependency):
Expand All @@ -28,6 +29,7 @@ class RepoDependency(Dependency):

def __init__(self, def_repo, stage, *args, **kwargs):
    # `def_repo` is the dict describing the external repo dependency
    # (keys: PARAM_URL, PARAM_REV, PARAM_REV_LOCK); kept mutable so the
    # rev lock can be pinned later (see get_used_objs/get_obj).
    self.def_repo = def_repo
    # Cache of staged objects keyed by resolved git revision, populated
    # by get_obj()/get_used_objs() to avoid re-staging the same rev.
    self._staged_objs: Dict[str, "HashFile"] = {}
    super().__init__(stage, *args, **kwargs)

def _parse_path(self, fs, path_info):
Expand All @@ -37,37 +39,12 @@ def _parse_path(self, fs, path_info):
def is_in_repo(self):
    # An external-repo dependency never lives inside the current repo.
    return False

@property
def repo_pair(self) -> RepoPair:
d = self.def_repo
rev = d.get(self.PARAM_REV_LOCK) or d.get(self.PARAM_REV)
return RepoPair(d[self.PARAM_URL], rev)

def __str__(self):
    # User-facing form: "<path> (<repo url>)".
    return f"{self.def_path} ({self.def_repo[self.PARAM_URL]})"

def _make_repo(self, *, locked=True, **kwargs):
from dvc.external_repo import external_repo

d = self.def_repo
rev = (d.get("rev_lock") if locked else None) or d.get("rev")
return external_repo(d["url"], rev=rev, **kwargs)

def _get_hash(self, locked=True):
from dvc.objects.stage import stage

with self._make_repo(locked=locked) as repo:
path_info = PathInfo(repo.root_dir) / self.def_path
return stage(
self.repo.odb.local,
path_info,
repo.repo_fs,
self.repo.odb.local.fs.PARAM_CHECKSUM,
).hash_info

def workspace_status(self):
current = self._get_hash(locked=True)
updated = self._get_hash(locked=False)
current = self.get_obj(locked=True).hash_info
updated = self.get_obj(locked=False).hash_info

if current != updated:
return {str(self): "update available"}
Expand All @@ -85,39 +62,28 @@ def dumpd(self):

def download(self, to, jobs=None):
from dvc.checkout import checkout
from dvc.config import NoRemoteError
from dvc.exceptions import NoOutputOrStageError
from dvc.objects import save
from dvc.objects.stage import stage
from dvc.objects.db.git import GitObjectDB
from dvc.repo.fetch import fetch_from_odb

odb = self.repo.odb.local

with self._make_repo(cache_dir=odb.cache_dir) as repo:
if self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev()
path_info = PathInfo(repo.root_dir) / self.def_path
try:
repo.fetch([path_info.fspath], jobs=jobs, recursive=True)
except (NoOutputOrStageError, NoRemoteError):
pass
obj = stage(
odb, path_info, repo.repo_fs, odb.fs.PARAM_CHECKSUM, jobs=jobs
)
save(odb, obj, jobs=jobs)
for odb, objs in self.get_used_objs().items():
if not isinstance(odb, GitObjectDB):
fetch_from_odb(self.repo, odb, objs, jobs=jobs)

obj = self.get_obj()
save(self.repo.odb.local, obj, jobs=jobs)
checkout(
to.path_info,
to.fs,
obj,
odb,
self.repo.odb.local,
dvcignore=None,
state=self.repo.state,
)

def update(self, rev=None):
if rev:
self.def_repo[self.PARAM_REV] = rev

with self._make_repo(locked=False) as repo:
self.def_repo[self.PARAM_REV_LOCK] = repo.get_rev()

Expand All @@ -126,3 +92,94 @@ def changed_checksum(self):
# origin project url and rev_lock, and it makes RepoDependency
# immutable, hence it's impossible for checksum to change.
return False

def get_used_objs(
    self, **kwargs
) -> Dict[Optional["ObjectDB"], Set["HashFile"]]:
    """Return the objects this dependency uses, grouped by source ODB.

    Keyword args:
        locked: when True (default), resolve the pinned rev_lock and pin
            it into ``self.def_repo`` if not already set.
        jobs: forwarded to staging/used-object collection.

    Returns a mapping of ODB -> set of objects. Objects with no cache
    entry in the external repo are attributed to a GitObjectDB built on
    the external repo's filesystem. NOTE(review): the return annotation
    allows a None key, but this implementation always replaces None with
    the external repo's default remote ODB — confirm against callers.
    """
    from dvc.config import NoRemoteError
    from dvc.exceptions import NoOutputOrStageError
    from dvc.objects.db.git import GitObjectDB
    from dvc.objects.stage import stage

    local_odb = self.repo.odb.local
    locked = kwargs.pop("locked", True)
    with self._make_repo(
        locked=locked, cache_dir=local_odb.cache_dir
    ) as repo:
        used_objs = defaultdict(set)
        rev = repo.get_rev()
        # Pin the resolved revision so subsequent locked operations are
        # reproducible.
        if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None:
            self.def_repo[self.PARAM_REV_LOCK] = rev

        path_info = PathInfo(repo.root_dir) / str(self.def_path)
        try:
            for odb, objs in repo.used_objs(
                [os.fspath(path_info)],
                force=True,
                jobs=kwargs.get("jobs"),
                recursive=True,
            ).items():
                if odb is None:
                    # None means "default remote" — resolve it so the
                    # caller gets a concrete ODB to fetch from.
                    odb = repo.cloud.get_remote().odb
                self._check_circular_import(odb)
                used_objs[odb].update(objs)
        except (NoRemoteError, NoOutputOrStageError):
            # Best-effort: the external repo may have no remote or no
            # matching output/stage; fall through to staging from git.
            pass

        # Stage the dependency path directly from the external repo fs
        # and attribute it to a dummy read-only git-backed ODB.
        staged_obj = stage(
            local_odb,
            path_info,
            repo.repo_fs,
            local_odb.fs.PARAM_CHECKSUM,
        )
        self._staged_objs[rev] = staged_obj
        git_odb = GitObjectDB(repo.repo_fs, repo.root_dir)
        used_objs[git_odb].add(staged_obj)
        return used_objs

def _check_circular_import(self, odb):
    """Raise CircularImportError when `odb` points back at our own repo."""
    from dvc.exceptions import CircularImportError
    from dvc.fs.repo import RepoFileSystem

    # Only repo-backed ODBs (i.e. other local DVC repos) can form a cycle.
    if not odb or not isinstance(odb.fs, RepoFileSystem):
        return

    own_url = self.repo.url or self.repo.root_dir
    remote_url = odb.fs.repo_url
    if remote_url is not None and remote_url == own_url:
        raise CircularImportError(self, remote_url, own_url)

def get_obj(self, filter_info=None, **kwargs):
    """Return the staged object for this dependency's path.

    Keyword args:
        locked: when True (default), resolve the pinned rev_lock and pin
            it into ``self.def_repo`` if not already set.

    NOTE(review): `filter_info` is accepted but unused here — confirm
    whether callers rely on it (likely kept for interface parity with
    other dependency/output types).
    """
    from dvc.objects.stage import stage

    odb = self.repo.odb.local
    locked = kwargs.pop("locked", True)
    with self._make_repo(locked=locked, cache_dir=odb.cache_dir) as repo:
        rev = repo.get_rev()
        if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None:
            self.def_repo[self.PARAM_REV_LOCK] = rev
        # Reuse an object already staged for this revision (also filled
        # by get_used_objs) to avoid redundant staging work.
        obj = self._staged_objs.get(rev)
        if obj is not None:
            return obj

        path_info = PathInfo(repo.root_dir) / str(self.def_path)
        obj = stage(
            odb,
            path_info,
            repo.repo_fs,
            odb.fs.PARAM_CHECKSUM,
        )
        self._staged_objs[rev] = obj
        return obj

def _make_repo(self, locked=True, **kwargs):
    """Open the external repo at the appropriate revision (context mgr)."""
    from dvc.external_repo import external_repo

    url = self.def_repo[self.PARAM_URL]
    rev = self._get_rev(locked=locked)
    return external_repo(url, rev=rev, **kwargs)

def _get_rev(self, locked=True):
    """Revision to use: the pinned rev_lock when locked, else raw rev."""
    d = self.def_repo
    if locked:
        pinned = d.get(self.PARAM_REV_LOCK)
        if pinned:
            return pinned
    return d.get(self.PARAM_REV)
8 changes: 8 additions & 0 deletions dvc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,11 @@ def __init__(self, path_infos):
)
super().__init__(msg)
self.path_infos = path_infos


class CircularImportError(DvcException):
    """Raised when an import dependency resolves back to the importing repo."""

    def __init__(self, dep, a, b):
        message = (
            f"'{dep}' contains invalid circular import. "
            f"DVC repo '{a}' already imports from '{b}'."
        )
        super().__init__(message)
6 changes: 6 additions & 0 deletions dvc/fs/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def __init__(
if hasattr(repo, "dvc_dir"):
self._dvcfss[repo.root_dir] = DvcFileSystem(repo=repo)

@property
def repo_url(self):
    """URL of the main repo backing this filesystem, or None if unset."""
    main_repo = self._main_repo
    if main_repo is None:
        return None
    return main_repo.url

def _get_repo(self, path: str) -> Optional["Repo"]:
"""Returns repo that the path falls in, using prefix.
Expand Down
14 changes: 12 additions & 2 deletions dvc/objects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, Optional

from dvc.progress import Tqdm

from .tree import Tree

if TYPE_CHECKING:
from .db.base import ObjectDB
from .file import HashFile

logger = logging.getLogger(__name__)


def save(odb, obj, jobs=None, **kwargs):
def save(
odb: "ObjectDB",
obj: "HashFile",
jobs: Optional[int] = None,
**kwargs,
):
if isinstance(obj, Tree):
with ThreadPoolExecutor(max_workers=jobs) as executor:
for future in Tqdm(
Expand All @@ -18,7 +28,7 @@ def save(odb, obj, jobs=None, **kwargs):
entry.path_info,
entry.fs,
entry.hash_info,
**kwargs
**kwargs,
)
for _, entry in obj
),
Expand Down
17 changes: 17 additions & 0 deletions dvc/objects/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@ def __init__(self, fs, path_info, **config):
self.cache_types = config.get("type") or copy(self.DEFAULT_CACHE_TYPES)
self.cache_type_confirmed = False
self.slow_link_warning = config.get("slow_link_warning", True)
self.tmp_dir = config.get("tmp_dir")

@property
def config(self):
    """Kwargs sufficient to reconstruct an equivalent ODB instance."""
    return dict(
        state=self.state,
        verify=self.verify,
        type=self.cache_types,
        slow_link_warning=self.slow_link_warning,
        tmp_dir=self.tmp_dir,
    )

def __eq__(self, other):
    # ODBs compare by location (fs + path) so they can serve as dict
    # keys, e.g. in the odb -> objects mapping built by get_used_objs.
    # Return NotImplemented — rather than raising AttributeError — when
    # `other` is not ODB-like, so `odb == <arbitrary object>` falls back
    # to Python's default identity comparison as the data model requires.
    if not (hasattr(other, "fs") and hasattr(other, "path_info")):
        return NotImplemented
    return self.fs == other.fs and self.path_info == other.path_info

def __hash__(self):
    # Hash by (scheme, location), mirroring __eq__'s location-based
    # equality so equal ODBs land in the same hash bucket.
    key = (self.fs.scheme, self.path_info)
    return hash(key)

def move(self, from_info, to_info):
self.fs.move(from_info, to_info)
Expand Down
24 changes: 24 additions & 0 deletions dvc/objects/db/git.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import logging

from .base import ObjectDB

logger = logging.getLogger(__name__)


class GitObjectDB(ObjectDB):
    """Dummy read-only ODB for uncached objects in external Git repos.

    Used to attribute objects that exist only in an external repo's
    git worktree (accessed through a RepoFileSystem) rather than in any
    real object cache. All mutating/lookup operations are unsupported.
    """

    def __init__(self, fs, path_info, **config):
        from dvc.fs.repo import RepoFileSystem

        # Validate with an explicit raise instead of `assert`, so the
        # check is not stripped when running under `python -O`.
        if not isinstance(fs, RepoFileSystem):
            raise TypeError(
                f"GitObjectDB requires a RepoFileSystem, got {type(fs)!r}"
            )
        # NOTE(review): `config` is deliberately discarded — this ODB is
        # a read-only placeholder with no cache configuration of its own.
        super().__init__(fs, path_info)

    def get(self, hash_info):
        # Unsupported: there are no object files to retrieve.
        raise NotImplementedError

    def add(self, path_info, fs, hash_info, move=True, **kwargs):
        # Unsupported: read-only ODB, nothing can be added.
        raise NotImplementedError

    def gc(self, used, jobs=None):
        # Unsupported: read-only ODB, nothing to garbage-collect.
        raise NotImplementedError
Loading

0 comments on commit 1d3524f

Please sign in to comment.