Skip to content

Commit

Permalink
types: add types to dvc/tree (iterative#4947)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Nov 24, 2020
1 parent 3ba70d3 commit 04c5360
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 22 deletions.
14 changes: 8 additions & 6 deletions dvc/tree/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from multiprocessing import cpu_count
from typing import Any, ClassVar, Dict, Optional
from urllib.parse import urlparse

from funcy import cached_property, decorator
Expand Down Expand Up @@ -46,8 +47,8 @@ def use_state(call):

class BaseTree:
scheme = "base"
REQUIRES = {}
PATH_CLS = URLInfo
REQUIRES: ClassVar[Dict[str, str]] = {}
PATH_CLS = URLInfo # type: Any
JOBS = 4 * cpu_count()

CHECKSUM_DIR_SUFFIX = ".dir"
Expand All @@ -59,9 +60,9 @@ class BaseTree:
TRAVERSE_THRESHOLD_SIZE = 500000
CAN_TRAVERSE = True

CACHE_MODE = None
CACHE_MODE: Optional[int] = None
SHARED_MODE_MAP = {None: (None, None), "group": (None, None)}
PARAM_CHECKSUM = None
PARAM_CHECKSUM: ClassVar[Optional[str]] = None

state = StateNoop()

Expand Down Expand Up @@ -155,9 +156,10 @@ def dir_mode(self):
def cache(self):
return getattr(self.repo.cache, self.scheme)

def open(self, path_info, mode="r", encoding=None):
def open(self, path_info, mode: str = "r", encoding: str = None):
if hasattr(self, "_generate_download_url"):
func = self._generate_download_url # noqa,pylint:disable=no-member
# pylint:disable=no-member
func = self._generate_download_url # type: ignore[attr-defined]
get_url = partial(func, path_info)
return open_url(get_url, mode=mode, encoding=encoding)

Expand Down
1 change: 0 additions & 1 deletion dvc/tree/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class DvcTree(BaseTree): # pylint:disable=abstract-method

scheme = "local"
PARAM_CHECKSUM = "md5"
_dir_entry_hashes = {}

def __init__(self, repo, fetch=False, stream=False):
super().__init__(repo, {"url": repo.root_dir})
Expand Down
6 changes: 5 additions & 1 deletion dvc/tree/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import stat
from typing import Any, Dict

from funcy import cached_property
from shortuuid import uuid
Expand All @@ -27,7 +28,10 @@ class LocalTree(BaseTree):
TRAVERSE_PREFIX_LEN = 2

CACHE_MODE = 0o444
SHARED_MODE_MAP = {None: (0o644, 0o755), "group": (0o664, 0o775)}
SHARED_MODE_MAP: Dict[Any, Any] = {
None: (0o644, 0o755),
"group": (0o664, 0o775),
}

def __init__(self, repo, config, use_dvcignore=False, dvcignore_root=None):
super().__init__(repo, config)
Expand Down
20 changes: 6 additions & 14 deletions dvc/tree/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import threading
from contextlib import suppress
from itertools import takewhile
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Type, Union

from funcy import lfilter, wrap_with

Expand All @@ -22,12 +22,10 @@
if TYPE_CHECKING:
from dvc.repo import Repo

from .git import GitTree
from .local import LocalTree


logger = logging.getLogger(__name__)

RepoFactory = Union[Callable[[str], "Repo"], Type["Repo"]]


class RepoTree(BaseTree): # pylint:disable=abstract-method
"""DVC + git-tracked files tree.
Expand All @@ -43,18 +41,14 @@ class RepoTree(BaseTree): # pylint:disable=abstract-method
PARAM_CHECKSUM = "md5"

def __init__(
self,
repo,
subrepos=False,
repo_factory: Callable[[str], "Repo"] = None,
**kwargs
self, repo, subrepos=False, repo_factory: RepoFactory = None, **kwargs
):
super().__init__(repo, {"url": repo.root_dir})

if not repo_factory:
from dvc.repo import Repo

self.repo_factory = Repo
self.repo_factory: RepoFactory = Repo
else:
self.repo_factory = repo_factory

Expand Down Expand Up @@ -120,9 +114,7 @@ def _is_dvc_repo(self, dir_path):
# dvcignore will ignore subrepos, therefore using `use_dvcignore=False`
return self._main_repo.tree.isdir(repo_path, use_dvcignore=False)

def _get_tree_pair(
self, path
) -> Tuple[Union["GitTree", "LocalTree"], DvcTree]:
def _get_tree_pair(self, path) -> Tuple[BaseTree, Optional[DvcTree]]:
"""
Returns a pair of trees based on repo the path falls in, using prefix.
"""
Expand Down

0 comments on commit 04c5360

Please sign in to comment.