Skip to content

Commit

Permalink
Make repo lock reentrable and provide a context manager (iterative#4416)
Browse files Browse the repository at this point in the history
  • Loading branch information
Suor authored Sep 8, 2020
1 parent 61bf119 commit 67ddb7b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 23 deletions.
40 changes: 23 additions & 17 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,30 @@
logger = logging.getLogger(__name__)


@contextmanager
def lock_repo(repo):
# pylint: disable=protected-access
depth = getattr(repo, "_lock_depth", 0)
repo._lock_depth = depth + 1

try:
if depth > 0:
yield
else:
with repo.lock, repo.state:
repo._reset()
yield
# Graph cache is no longer valid after we release the repo.lock
repo._reset()
finally:
repo._lock_depth = depth


def locked(f):
@wraps(f)
def wrapper(repo, *args, **kwargs):
with repo.lock, repo.state:
# pylint: disable=protected-access
repo._reset()
ret = f(repo, *args, **kwargs)
# Our graph cache is no longer valid after we release the repo.lock
repo._reset()
return ret
with lock_repo(repo):
return f(repo, *args, **kwargs)

return wrapper

Expand All @@ -44,11 +58,11 @@ class Repo:

from dvc.repo.add import add
from dvc.repo.brancher import brancher
from dvc.repo.checkout import _checkout
from dvc.repo.checkout import checkout
from dvc.repo.commit import commit
from dvc.repo.destroy import destroy
from dvc.repo.diff import diff
from dvc.repo.fetch import _fetch
from dvc.repo.fetch import fetch
from dvc.repo.freeze import freeze, unfreeze
from dvc.repo.gc import gc
from dvc.repo.get import get
Expand Down Expand Up @@ -607,14 +621,6 @@ def open_by_relpath(self, path, remote=None, mode="r", encoding=None):
def close(self):
self.scm.close()

@locked
def checkout(self, *args, **kwargs):
return self._checkout(*args, **kwargs)

@locked
def fetch(self, *args, **kwargs):
return self._fetch(*args, **kwargs)

def _reset(self):
self.__dict__.pop("graph", None)
self.__dict__.pop("stages", None)
Expand Down
5 changes: 4 additions & 1 deletion dvc/repo/checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from dvc.progress import Tqdm
from dvc.utils import relpath

from . import locked

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -36,7 +38,8 @@ def get_all_files_numbers(pairs):
)


def _checkout(
@locked
def checkout(
self,
targets=None,
with_deps=False,
Expand Down
2 changes: 1 addition & 1 deletion dvc/repo/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def checkout_exp(self, rev):
"""Checkout an experiment to the user's workspace."""
from git.exc import GitCommandError

from dvc.repo.checkout import _checkout as dvc_checkout
from dvc.repo.checkout import checkout as dvc_checkout

self._check_baseline(rev)
self._scm_checkout(rev)
Expand Down
5 changes: 4 additions & 1 deletion dvc/repo/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from dvc.exceptions import DownloadError
from dvc.scm.base import CloneError

from . import locked

logger = logging.getLogger(__name__)


def _fetch(
@locked
def fetch(
self,
targets=None,
jobs=None,
Expand Down
4 changes: 2 additions & 2 deletions dvc/repo/pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def pull(
if isinstance(targets, str):
targets = [targets]

processed_files_count = self._fetch( # pylint: disable=protected-access
processed_files_count = self.fetch(
targets,
jobs,
remote=remote,
Expand All @@ -33,7 +33,7 @@ def pull(
recursive=recursive,
run_cache=run_cache,
)
stats = self._checkout( # pylint: disable=protected-access
stats = self.checkout(
targets=targets, with_deps=with_deps, force=force, recursive=recursive
)

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/repo/test_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def test_used_cache(tmp_dir, dvc, path):

def test_locked(mocker):
repo = mocker.MagicMock()
repo._lock_depth = 0
repo.method = locked(repo.method)

args = {}
args = ()
kwargs = {}
repo.method(repo, args, kwargs)

Expand Down

0 comments on commit 67ddb7b

Please sign in to comment.