
Commit

Move type annotations for remaining torch.utils stub files inline (pytorch#43406)

Summary: Pull Request resolved: pytorch#43406

Reviewed By: mruberry

Differential Revision: D23319736

Pulled By: malfet

fbshipit-source-id: e25fbb49f27aa4893590b022441303d6d98263a9
rgommers authored and facebook-github-bot committed Sep 1, 2020
1 parent 6022097 commit da32bf4
Showing 7 changed files with 32 additions and 41 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
@@ -227,6 +227,9 @@ ignore_errors = True
 [mypy-torch.utils.data._utils.worker]
 ignore_errors = True
 
+[mypy-torch.utils.data.distributed]
+ignore_errors = True
+
 [mypy-torch.nn.utils.prune]
 ignore_errors = True

9 changes: 8 additions & 1 deletion torch/utils/data/__init__.py
@@ -1,4 +1,11 @@
 from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
-from .distributed import DistributedSampler
 from .dataset import Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, Subset, random_split
+from .distributed import DistributedSampler
 from .dataloader import DataLoader, _DatasetKind, get_worker_info
+
+
+__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
+           'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
+           'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
+           'ConcatDataset', 'ChainDataset', 'Subset', 'random_split',
+           'DataLoader', '_DatasetKind', 'get_worker_info']
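
With the annotations now inline, mypy checks callers of torch.utils.data against the implementation itself, with no stub to shadow it. A minimal sketch of what that buys (the two-replica split and values are illustrative, not part of this commit):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    # num_replicas and rank are Optional[int]; a call like rank="0" would now
    # be flagged by mypy as an argument type error rather than slipping through.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
    assert len(sampler) == 5  # ceil(10 / 2) samples per replica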
7 changes: 0 additions & 7 deletions torch/utils/data/__init__.pyi

This file was deleted.

20 changes: 13 additions & 7 deletions torch/utils/data/distributed.py
@@ -1,10 +1,15 @@
 import math
+from typing import TypeVar, Optional, Iterator
 
 import torch
-from . import Sampler
+from . import Sampler, Dataset
 import torch.distributed as dist
 
 
-class DistributedSampler(Sampler):
+T_co = TypeVar('T_co', covariant=True)
+
+
+class DistributedSampler(Sampler[T_co]):
     r"""Sampler that restricts data loading to a subset of the dataset.
 
     It is especially useful in conjunction with
@@ -51,7 +56,9 @@ class DistributedSampler(Sampler):
         ...     train(loader)
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False):
+    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
+                 rank: Optional[int] = None, shuffle: bool = True,
+                 seed: int = 0, drop_last: bool = False) -> None:
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -80,7 +87,7 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0,
         self.shuffle = shuffle
         self.seed = seed
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T_co]:
         if self.shuffle:
             # deterministically shuffle based on epoch and seed
             g = torch.Generator()
@@ -89,7 +96,6 @@ def __iter__(self):
         else:
             indices = list(range(len(self.dataset)))
 
-
         if not self.drop_last:
             # add extra samples to make it evenly divisible
             indices += indices[:(self.total_size - len(indices))]
@@ -104,10 +110,10 @@ def __iter__(self):
 
         return iter(indices)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_samples
 
-    def set_epoch(self, epoch):
+    def set_epoch(self, epoch: int) -> None:
         r"""
         Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
         use a different random ordering for each epoch. Otherwise, the next iteration of this
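
For reference, the usage pattern the (truncated) docstring example describes: call set_epoch before each epoch so that __iter__ reseeds its generator from seed + epoch identically on every replica. In this sketch, dataset, train, and num_epochs are placeholders and an initialized process group is assumed:

    from torch.utils.data import DataLoader, DistributedSampler

    sampler = DistributedSampler(dataset)  # num_replicas/rank taken from the process group
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
        train(loader)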
9 changes: 0 additions & 9 deletions torch/utils/data/distributed.pyi

This file was deleted.

14 changes: 8 additions & 6 deletions torch/utils/hooks.py
@@ -2,27 +2,29 @@
 from collections import OrderedDict
 import weakref
 import warnings
+from typing import Any
 
 
 class RemovableHandle(object):
     """A handle which provides the capability to remove a hook."""
 
-    next_id = 0
+    id: int
+    next_id: int = 0
 
-    def __init__(self, hooks_dict):
+    def __init__(self, hooks_dict: Any) -> None:
         self.hooks_dict_ref = weakref.ref(hooks_dict)
         self.id = RemovableHandle.next_id
         RemovableHandle.next_id += 1
 
-    def remove(self):
+    def remove(self) -> None:
         hooks_dict = self.hooks_dict_ref()
         if hooks_dict is not None and self.id in hooks_dict:
             del hooks_dict[self.id]
 
     def __getstate__(self):
         return (self.hooks_dict_ref(), self.id)
 
-    def __setstate__(self, state):
+    def __setstate__(self, state) -> None:
         if state[0] is None:
             # create a dead reference
             self.hooks_dict_ref = weakref.ref(OrderedDict())
@@ -31,10 +33,10 @@ def __setstate__(self, state):
         self.id = state[1]
         RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)
 
-    def __enter__(self):
+    def __enter__(self) -> 'RemovableHandle':
         return self
 
-    def __exit__(self, type, value, tb):
+    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
         self.remove()


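
RemovableHandle is what hook-registration APIs such as Module.register_forward_hook return, and the newly typed __enter__/__exit__ above cover its use as a context manager. A small sketch (the hook body is illustrative):

    import torch
    import torch.nn as nn

    module = nn.Linear(2, 2)

    def print_shape(mod, inputs, output):
        print(output.shape)

    # __enter__ returns the handle; __exit__ calls remove(), so the hook is
    # unregistered as soon as the block ends.
    with module.register_forward_hook(print_shape) as handle:
        module(torch.randn(3, 2))   # prints torch.Size([3, 2])
    module(torch.randn(3, 2))       # no output: the hook is gone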
11 changes: 0 additions & 11 deletions torch/utils/hooks.pyi

This file was deleted.
