diff --git a/dvc/repo/fetch.py b/dvc/repo/fetch.py index 26b2226012..924a57737c 100644 --- a/dvc/repo/fetch.py +++ b/dvc/repo/fetch.py @@ -1,6 +1,6 @@ import logging from contextlib import suppress -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Sequence from dvc.config import NoRemoteError from dvc.exceptions import DownloadError @@ -9,6 +9,7 @@ from . import locked if TYPE_CHECKING: + from dvc.cloud import Remote from dvc.repo import Repo from dvc.types import TargetType from dvc_data.hashfile.transfer import TransferResult @@ -45,18 +46,19 @@ def fetch( remote is configured """ from dvc.repo.imports import save_imports - from dvc.repo.worktree import fetch_worktree from dvc_data.hashfile.transfer import TransferResult if isinstance(targets, str): targets = [targets] + worktree_remote: Optional["Remote"] = None with suppress(NoRemoteError): _remote = self.cloud.get_remote(name=remote) if _remote.worktree or _remote.fs.version_aware: - return fetch_worktree(self, _remote, targets=targets) + worktree_remote = _remote failed_count = 0 + transferred_count = 0 try: if run_cache: @@ -67,22 +69,35 @@ def fetch( no_remote_msg: Optional[str] = None result = TransferResult(set(), set()) try: - d, f = _fetch( - self, - targets, - all_branches=all_branches, - all_tags=all_tags, - all_commits=all_commits, - with_deps=with_deps, - force=True, - remote=remote, - jobs=jobs, - recursive=recursive, - revs=revs, - odb=odb, - ) - result.transferred.update(d) - result.failed.update(f) + if worktree_remote is not None: + transferred_count += _fetch_worktree( + self, + worktree_remote, + revs=revs, + all_branches=all_branches, + all_tags=all_tags, + all_commits=all_commits, + targets=targets, + with_deps=with_deps, + recursive=recursive, + ) + else: + d, f = _fetch( + self, + targets, + all_branches=all_branches, + all_tags=all_tags, + all_commits=all_commits, + with_deps=with_deps, + force=True, + remote=remote, + jobs=jobs, + recursive=recursive, + revs=revs, + odb=odb, + ) + result.transferred.update(d) + result.failed.update(f) except NoRemoteError as exc: no_remote_msg = str(exc) @@ -108,7 +123,8 @@ def fetch( logger.error(no_remote_msg) raise DownloadError(failed_count) - return len(result.transferred) + transferred_count += len(result.transferred) + return transferred_count def _fetch( @@ -155,3 +171,26 @@ def _fetch( result.transferred.update(d) result.failed.update(f) return result + + +def _fetch_worktree( + repo: "Repo", + remote: "Remote", + revs: Optional[Sequence[str]] = None, + all_branches: bool = False, + all_tags: bool = False, + all_commits: bool = False, + targets: Optional["TargetType"] = None, + **kwargs, +) -> int: + from dvc.repo.worktree import fetch_worktree + + downloaded = 0 + for _ in repo.brancher( + revs=revs, + all_branches=all_branches, + all_tags=all_tags, + all_commits=all_commits, + ): + downloaded += fetch_worktree(repo, remote, targets=targets, **kwargs) + return downloaded