Skip to content

Commit

Permalink
worktree fetch: support revs flags
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Dec 9, 2022
1 parent 7a65d67 commit 959d149
Showing 1 changed file with 59 additions and 20 deletions.
79 changes: 59 additions & 20 deletions dvc/repo/fetch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from contextlib import suppress
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Sequence

from dvc.config import NoRemoteError
from dvc.exceptions import DownloadError
Expand All @@ -9,6 +9,7 @@
from . import locked

if TYPE_CHECKING:
from dvc.cloud import Remote
from dvc.repo import Repo
from dvc.types import TargetType
from dvc_data.hashfile.transfer import TransferResult
Expand Down Expand Up @@ -45,18 +46,19 @@ def fetch(
remote is configured
"""
from dvc.repo.imports import save_imports
from dvc.repo.worktree import fetch_worktree
from dvc_data.hashfile.transfer import TransferResult

if isinstance(targets, str):
targets = [targets]

worktree_remote: Optional["Remote"] = None
with suppress(NoRemoteError):
_remote = self.cloud.get_remote(name=remote)
if _remote.worktree or _remote.fs.version_aware:
return fetch_worktree(self, _remote, targets=targets)
worktree_remote = _remote

failed_count = 0
transferred_count = 0

try:
if run_cache:
Expand All @@ -67,22 +69,35 @@ def fetch(
no_remote_msg: Optional[str] = None
result = TransferResult(set(), set())
try:
d, f = _fetch(
self,
targets,
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
with_deps=with_deps,
force=True,
remote=remote,
jobs=jobs,
recursive=recursive,
revs=revs,
odb=odb,
)
result.transferred.update(d)
result.failed.update(f)
if worktree_remote is not None:
transferred_count += _fetch_worktree(
self,
worktree_remote,
revs=revs,
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
targets=targets,
with_deps=with_deps,
recursive=recursive,
)
else:
d, f = _fetch(
self,
targets,
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
with_deps=with_deps,
force=True,
remote=remote,
jobs=jobs,
recursive=recursive,
revs=revs,
odb=odb,
)
result.transferred.update(d)
result.failed.update(f)
except NoRemoteError as exc:
no_remote_msg = str(exc)

Expand All @@ -108,7 +123,8 @@ def fetch(
logger.error(no_remote_msg)
raise DownloadError(failed_count)

return len(result.transferred)
transferred_count += len(result.transferred)
return transferred_count


def _fetch(
Expand Down Expand Up @@ -155,3 +171,26 @@ def _fetch(
result.transferred.update(d)
result.failed.update(f)
return result


def _fetch_worktree(
repo: "Repo",
remote: "Remote",
revs: Optional[Sequence[str]] = None,
all_branches: bool = False,
all_tags: bool = False,
all_commits: bool = False,
targets: Optional["TargetType"] = None,
**kwargs,
) -> int:
from dvc.repo.worktree import fetch_worktree

downloaded = 0
for _ in repo.brancher(
revs=revs,
all_branches=all_branches,
all_tags=all_tags,
all_commits=all_commits,
):
downloaded += fetch_worktree(repo, remote, targets=targets, **kwargs)
return downloaded

0 comments on commit 959d149

Please sign in to comment.