objects: use separate staging ODB for staging trees (iterative#6195)
* fs: use fsspecwrapper in memfs

* odb: add memfs staging ODB in ODBManager

* odb: add read_only attribute for odb instances

* objects: handle memfs paths when staging trees

* odb: remove git ODB

* output: use staging ODB

* RepoDependency: use staging odb for imports

* dvcfs: handle case where dir cache is unavailable

* fix circular import check

* odb: use temp local ODB for staging

* objects: use staging ODB in stage() and tree load

* objects: add typing for add/check/load/stage

* output/dep: use new staging

* add unit tests for local odb staging

* localfs: account for string paths in move()

* include local in odb.by_scheme

* tests: count dir cache like any other object in test_pull_git_imports

* odb: clear staging on gc()

* objects: add tree.get/filter/insert

* output: use obj.filter/get for granular dir commit

* squash filter logic output

* dependency: don't filter imports at the dep level

* don't init staging ODB for external cache

* use staging ODB when computing tree hash for external outs

* remove unneeded force in test granular commit

* use unique-per-manager memfs URLs for staging

* odb: make staging per-ODBManager

* objects: move state dependent staging into ODBManager

* objects: support checking multiple odbs

- staging-dependent checks can be done via ODBManager.check

* checkout: do state lookup before explicit stage() in _changed()

* output: use odbmanager based staging

* RepoDependency: use odbmanager based staging

* update obj staging unit tests

* checkout: remove check(), handle individual file exceptions on link()

* diff: use odbmanager based staging

* objects: add ObjectPermissionError for read/write perm errors

* odb: move staging out of odbmanager

* objects: handle staging ODB in objects.stage()

* objects.tree: don't stage anything in tree methods

* odb: clear staged objects on gc()

* remote: don't filter objs by scheme, objs are already grouped by ODB

* diff: update objects.stage usage

* output/dep: update stage() usage

* update ODB tests

* update func tests

* remove memfs from FS_MAP

* revert unneeded objects.check() changes

* remove unused tree.insert()

* revert unused odbmanager changes

* add separate odb error classes

* stage external outs directly in ODB

* check for obj existence in main ODB before staging

* use .dvc/tmp/staging instead of per-odb staging (see the sketch after this list)

* gc: only gc local staging once
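
A note on the end state: after this series, stage() hashes workspace data into a single temporary local ODB rooted at .dvc/tmp/staging, the main ODB is checked first so already-cached objects are never re-staged, and gc() clears staging in one pass. The toy sketch below restates that flow; ToyODB and toy_stage are illustrative stand-ins, not DVC's ObjectDB or objects.stage APIs.

import hashlib
import os
import shutil


class ToyODB:
    """Illustrative content-addressed store; a stand-in, not DVC's ObjectDB."""

    def __init__(self, root):
        self.root = root

    def _path(self, md5):
        # cache layout: the first two hex characters form the subdirectory
        return os.path.join(self.root, md5[:2], md5[2:])

    def exists(self, md5):
        return os.path.exists(self._path(md5))

    def add(self, src, md5):
        dst = self._path(md5)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copyfile(src, dst)

    def clear(self):
        # counterpart of "gc: only gc local staging once"
        shutil.rmtree(self.root, ignore_errors=True)


def toy_stage(path, staging, cache):
    """Hash `path`; stage it only when the main ODB lacks the object."""
    with open(path, "rb") as fobj:
        md5 = hashlib.md5(fobj.read()).hexdigest()
    if not cache.exists(md5):  # check the main ODB before staging
        staging.add(path, md5)  # write under .dvc/tmp/staging, not the cache
    return md5


staging = ToyODB(os.path.join(".dvc", "tmp", "staging"))
cache = ToyODB(os.path.join(".dvc", "cache"))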
pmrowla authored Jul 9, 2021
1 parent 67c46b9 commit 71bae7e
Showing 24 changed files with 473 additions and 269 deletions.
67 changes: 33 additions & 34 deletions dvc/checkout.py
@@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)


def _changed(path_info, fs, obj, cache):
def _changed(path_info, fs, obj, cache, state=None):
logger.trace("checking if '%s'('%s') has changed.", path_info, obj)

try:
@@ -32,11 +32,13 @@ def _changed(path_info, fs, obj, cache):
)
return True

try:
actual = stage(cache, path_info, fs, obj.hash_info.name).hash_info
except FileNotFoundError:
logger.debug("'%s' doesn't exist.", path_info)
return True
actual = state.get(path_info, fs) if state else None
if actual is None:
try:
actual = stage(cache, path_info, fs, obj.hash_info.name).hash_info
except FileNotFoundError:
logger.debug("'%s' doesn't exist.", path_info)
return True

if obj.hash_info != actual:
logger.debug(
@@ -120,9 +122,11 @@ def _try_links(cache, from_info, to_info, link_types):


def _link(cache, from_info, to_info):
assert cache.fs.isfile(from_info)
cache.makedirs(to_info.parent)
_try_links(cache, from_info, to_info, cache.cache_types)
try:
_try_links(cache, from_info, to_info, cache.cache_types)
except FileNotFoundError as exc:
raise CheckoutError([to_info]) from exc


def _cache_is_copy(cache, path_info):
@@ -163,7 +167,7 @@ def _checkout_file(
modified = False
cache_info = cache.hash_to_path_info(obj.hash_info.value)
if fs.exists(path_info):
if not relink and _changed(path_info, fs, obj, cache):
if not relink and _changed(path_info, fs, obj, cache, state=state):
modified = True
_remove(path_info, fs, cache, force=force)
_link(cache, cache_info, path_info)
@@ -222,19 +226,23 @@ def _checkout_dir(

logger.debug("Linking directory '%s'.", path_info)

failed = []
for entry_key, entry_obj in obj:
entry_modified = _checkout_file(
path_info.joinpath(*entry_key),
fs,
entry_obj,
cache,
force,
progress_callback,
relink,
state=None,
)
if entry_modified:
modified = True
try:
entry_modified = _checkout_file(
path_info.joinpath(*entry_key),
fs,
entry_obj,
cache,
force,
progress_callback,
relink,
state=None,
)
if entry_modified:
modified = True
except CheckoutError as exc:
failed.extend(exc.target_infos)

modified = (
_remove_redundant_files(
@@ -243,6 +251,9 @@
or modified
)

if failed:
raise CheckoutError(failed)

if state:
state.save(path_info, fs, obj.hash_info)

@@ -310,21 +321,9 @@ def checkout(
_remove(path_info, fs, cache, force=force)
failed = path_info

elif not relink and not _changed(path_info, fs, obj, cache):
elif not relink and not _changed(path_info, fs, obj, cache, state=state):
logger.trace("Data '%s' didn't change.", path_info) # type: ignore
skip = True
else:
try:
check(cache, obj)
except (FileNotFoundError, ObjectFormatError):
if not quiet:
logger.warning(
"Cache '%s' not found. File '%s' won't be created.",
obj.hash_info,
path_info,
)
_remove(path_info, fs, cache, force=force)
failed = path_info

if failed or skip:
if progress_callback and obj:
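The _changed() rewrite above is the hot path of this file: it consults the state database (DVC's mtime/inode-keyed hash cache) before falling back to rehashing, and _checkout_dir() now collects per-file CheckoutErrors instead of aborting on the first failure. A minimal sketch of the lookup-then-hash pattern, with the stage() and state.get() signatures taken from this diff and everything else simplified:

from dvc.objects.stage import stage


def workspace_hash(path_info, fs, cache, name, state=None):
    # cheap path: the state db may already know the hash for this mtime/inode
    actual = state.get(path_info, fs) if state else None
    if actual is None:
        # slow path: hash the file via stage(), exactly as _changed() does
        actual = stage(cache, path_info, fs, name).hash_info
    return actual
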
98 changes: 50 additions & 48 deletions dvc/dependency/repo.py
@@ -1,7 +1,8 @@
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Set
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple

from funcy import first
from voluptuous import Required

from dvc.path_info import PathInfo
@@ -62,16 +63,18 @@ def dumpd(self):

def download(self, to, jobs=None):
from dvc.checkout import checkout
from dvc.fs.memory import MemoryFileSystem
from dvc.objects import save
from dvc.objects.db.git import GitObjectDB
from dvc.repo.fetch import fetch_from_odb

for odb, objs in self.get_used_objs().items():
if not isinstance(odb, GitObjectDB):
if isinstance(odb.fs, MemoryFileSystem):
for obj in objs:
save(self.repo.odb.local, obj, jobs=jobs)
else:
fetch_from_odb(self.repo, odb, objs, jobs=jobs)

obj = self.get_obj()
save(self.repo.odb.local, obj, jobs=jobs)
checkout(
to.path_info,
to.fs,
@@ -96,10 +99,15 @@ def changed_checksum(self):
def get_used_objs(
self, **kwargs
) -> Dict[Optional["ObjectDB"], Set["HashFile"]]:
used, _ = self._get_used_and_obj(**kwargs)
return used

def _get_used_and_obj(
self, obj_only=False, **kwargs
) -> Tuple[Dict[Optional["ObjectDB"], Set["HashFile"]], "HashFile"]:
from dvc.config import NoRemoteError
from dvc.exceptions import NoOutputOrStageError, PathMissingError
from dvc.objects.db.git import GitObjectDB
from dvc.objects.stage import stage
from dvc.objects.stage import get_staging, stage

local_odb = self.repo.odb.local
locked = kwargs.pop("locked", True)
@@ -112,23 +120,25 @@ def get_used_objs(
self.def_repo[self.PARAM_REV_LOCK] = rev

path_info = PathInfo(repo.root_dir) / str(self.def_path)
try:
for odb, objs in repo.used_objs(
[os.fspath(path_info)],
force=True,
jobs=kwargs.get("jobs"),
recursive=True,
).items():
if odb is None:
odb = repo.cloud.get_remote().odb
self._check_circular_import(odb)
used_objs[odb].update(objs)
except (NoRemoteError, NoOutputOrStageError):
pass
if not obj_only:
try:
for odb, objs in repo.used_objs(
[os.fspath(path_info)],
force=True,
jobs=kwargs.get("jobs"),
recursive=True,
).items():
if odb is None:
odb = repo.cloud.get_remote().odb
odb.read_only = True
self._check_circular_import(objs)
used_objs[odb].update(objs)
except (NoRemoteError, NoOutputOrStageError):
pass

try:
staged_obj = stage(
local_odb,
None,
path_info,
repo.repo_fs,
local_odb.fs.PARAM_CHECKSUM,
@@ -137,45 +147,37 @@
raise PathMissingError(
self.def_path, self.def_repo[self.PARAM_URL]
) from exc
staging = get_staging()
staging.read_only = True

self._staged_objs[rev] = staged_obj
git_odb = GitObjectDB(repo.repo_fs, repo.root_dir)
used_objs[git_odb].add(staged_obj)
return used_objs
used_objs[staging].add(staged_obj)
return used_objs, staged_obj

def _check_circular_import(self, odb):
def _check_circular_import(self, objs):
from dvc.exceptions import CircularImportError
from dvc.fs.repo import RepoFileSystem
from dvc.objects.tree import Tree

if not odb or not isinstance(odb.fs, RepoFileSystem):
obj = first(objs)
if isinstance(obj, Tree):
_, obj = first(obj)
if not isinstance(obj.fs, RepoFileSystem):
return

self_url = self.repo.url or self.repo.root_dir
if odb.fs.repo_url is not None and odb.fs.repo_url == self_url:
raise CircularImportError(self, odb.fs.repo_url, self_url)
if obj.fs.repo_url is not None and obj.fs.repo_url == self_url:
raise CircularImportError(self, obj.fs.repo_url, self_url)

def get_obj(self, filter_info=None, **kwargs):
from dvc.objects.stage import stage

odb = self.repo.odb.local
locked = kwargs.pop("locked", True)
with self._make_repo(locked=locked, cache_dir=odb.cache_dir) as repo:
rev = repo.get_rev()
if locked and self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = rev
obj = self._staged_objs.get(rev)
if obj is not None:
return obj

path_info = PathInfo(repo.root_dir) / str(self.def_path)
obj = stage(
odb,
path_info,
repo.repo_fs,
odb.fs.PARAM_CHECKSUM,
)
self._staged_objs[rev] = obj
return obj
locked = kwargs.get("locked", True)
rev = self._get_rev(locked=locked)
if rev in self._staged_objs:
return self._staged_objs[rev]
_, obj = self._get_used_and_obj(
obj_only=True, filter_info=filter_info, **kwargs
)
return obj

def _make_repo(self, locked=True, **kwargs):
from dvc.external_repo import external_repo
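_check_circular_import() now inspects the staged objects rather than an ODB: it unwraps a Tree to its first entry and compares that object's source repo URL with the importing repo's URL. Restated as a self-contained predicate (imports as in the diff, logic condensed):

from funcy import first

from dvc.fs.repo import RepoFileSystem
from dvc.objects.tree import Tree


def is_circular_import(objs, self_url):
    obj = first(objs)
    if isinstance(obj, Tree):
        _, obj = first(obj)  # Tree iterates (key, entry_obj) pairs
    if not isinstance(obj.fs, RepoFileSystem):
        return False  # the data does not come from another DVC repo
    return obj.fs.repo_url is not None and obj.fs.repo_url == self_url
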
2 changes: 2 additions & 0 deletions dvc/fs/dvc.py
@@ -51,6 +51,8 @@ def _get_granular_hash(
# NOTE: use string paths here for performance reasons
key = tuple(relpath(path_info, out.path_info).split(os.sep))
out.get_dir_cache(remote=remote)
if out.obj is None:
raise FileNotFoundError
obj = out.obj.trie.get(key)
if obj:
return obj.hash_info
4 changes: 3 additions & 1 deletion dvc/fs/local.py
@@ -90,7 +90,9 @@ def stat(self, path):
return os.stat(path)

def move(self, from_info, to_info):
if from_info.scheme != "local" or to_info.scheme != "local":
if (
isinstance(from_info, PathInfo) and from_info.scheme != "local"
) or (isinstance(to_info, PathInfo) and to_info.scheme != "local"):
raise NotImplementedError

self.makedirs(to_info.parent)
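The old move() assumed PathInfo arguments, so the plain string paths that the new staging code passes would raise AttributeError on .scheme; the isinstance guard only validates the scheme when a PathInfo is given. A hedged illustration (the no-argument LocalFileSystem() construction is an assumption for brevity):

from dvc.fs.local import LocalFileSystem
from dvc.path_info import PathInfo

fs = LocalFileSystem()  # assumed constructible without arguments here
fs.move(PathInfo("a.txt"), PathInfo("b.txt"))  # PathInfo: scheme is checked
fs.move("b.txt", "c.txt")  # plain str: no .scheme attribute, guard skips it
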
49 changes: 25 additions & 24 deletions dvc/fs/memory.py
@@ -1,34 +1,35 @@
from .base import BaseFileSystem
import threading

from funcy import cached_property, wrap_prop

class MemoryFileSystem(BaseFileSystem):
scheme = "local"
PARAM_CHECKSUM = "md5"

def __init__(self, **kwargs):
from fsspec.implementations.memory import MemoryFileSystem as MemFS
from dvc.path_info import CloudURLInfo
from dvc.scheme import Schemes

super().__init__(**kwargs)
from .fsspec_wrapper import FSSpecWrapper

self.fs = MemFS()

def exists(self, path_info) -> bool:
return self.fs.exists(path_info.fspath)

def open(self, path_info, mode="r", encoding=None, **kwargs):
return self.fs.open(
path_info.fspath, mode=mode, encoding=encoding, **kwargs
)
class MemoryFileSystem(FSSpecWrapper): # pylint:disable=abstract-method
scheme = Schemes.MEMORY
PARAM_CHECKSUM = "md5"
PATH_CLS = CloudURLInfo
TRAVERSE_PREFIX_LEN = 2
DEFAULT_BLOCKSIZE = 4096

def info(self, path_info):
return self.fs.info(path_info.fspath)
def __eq__(self, other):
# NOTE: all fsspec MemoryFileSystem instances are equivalent and use a
# single global store
return isinstance(other, type(self))

def stat(self, path_info):
import os
__hash__ = FSSpecWrapper.__hash__

info = self.fs.info(path_info.fspath)
@wrap_prop(threading.Lock())
@cached_property
def fs(self):
from fsspec.implementations.memory import MemoryFileSystem as MemFS

return os.stat_result((0, 0, 0, 0, 0, 0, info["size"], 0, 0, 0))
return MemFS(**self.fs_args)

def walk_files(self, path_info, **kwargs):
raise NotImplementedError
def open(self, *args, **kwargs):
with super().open(*args, **kwargs) as fobj:
fobj.blocksize = self.DEFAULT_BLOCKSIZE
return fobj
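
The rewritten MemoryFileSystem delegates to fsspec, and the __eq__ override encodes an fsspec detail worth spelling out: every fsspec MemoryFileSystem shares one process-global, class-level store, so any two handles see the same files. A quick demonstration against fsspec directly:

from fsspec.implementations.memory import MemoryFileSystem as MemFS

a = MemFS()
b = MemFS()
with a.open("/staging/file.txt", "wb") as fobj:
    fobj.write(b"data")
# `b` reads what `a` wrote: the backing store is shared across instances
assert b.cat("/staging/file.txt") == b"data"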