Skip to content

Commit

Permalink
add type annotations to torch.nn.modules.container (pytorch#48969)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#48968

Pull Request resolved: pytorch#48969

Reviewed By: mrshenli

Differential Revision: D25728987

Pulled By: walterddr

fbshipit-source-id: 02c3aa2078f4ed6cc6edd90ffe1177d789c328a9
  • Loading branch information
guilhermeleobas authored and facebook-github-bot committed Jan 19, 2021
1 parent a1b1d0c commit a9e46f1
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 17 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ ignore_errors = True
[mypy-torch.testing._internal.distributed.*]
ignore_errors = True

[mypy-torch.nn.modules.container]
ignore_errors = True

[mypy-torch.nn.modules.pooling]
ignore_errors = True

Expand Down
32 changes: 32 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,38 @@ def get_element_size_script(x):
element_size = get_element_size_script(x)
self.assertEqual(element_size, x.element_size())

def test_Sequential(self):
class Seq(nn.Module):
def __init__(self):
super(Seq, self).__init__()
self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30))

@torch.jit.script_method
def forward(self, x):
for l in self.seq:
x = l(x)
return x

m = torch.jit.script(Seq())
assert m.graph # ensure jit was able to compile

def test_ModuleList(self):
class Mod(nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])
self.model += (nn.Linear(10, 20),)
self.model.append(nn.Linear(20, 30))
self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)])

def forward(self, v):
for m in self.model:
v = m(v)
return v

m = torch.jit.script(Mod())
assert m.graph # ensure jit was able to compile

def test_disabled(self):
torch.jit._state.disable()
try:
Expand Down
34 changes: 20 additions & 14 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from .module import Module
from torch._jit_internal import _copy_to_script_wrapper

from typing import Any, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing import Any, Iterable, Iterator, Mapping, Optional, TYPE_CHECKING, overload, Tuple, TypeVar, Union

if TYPE_CHECKING:
from torch.nn import Parameter

T = TypeVar('T')
T = TypeVar('T', bound=Module)


class Container(Module):
Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self, *args: Module) -> None:
def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
...

def __init__(self, *args: Any):
def __init__(self, *args):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
Expand All @@ -66,7 +68,7 @@ def __init__(self, *args: Any):
for idx, module in enumerate(args):
self.add_module(str(idx), module)

def _get_item_by_idx(self, iterator, idx):
def _get_item_by_idx(self, iterator, idx) -> T:
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
Expand All @@ -76,14 +78,14 @@ def _get_item_by_idx(self, iterator, idx):
return next(islice(iterator, idx, None))

@_copy_to_script_wrapper
def __getitem__(self: T, idx) -> T:
def __getitem__(self, idx) -> Union['Sequential', T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)

def __setitem__(self, idx: int, module: Module) -> None:
key = self._get_item_by_idx(self._modules.keys(), idx)
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)

def __delitem__(self, idx: Union[slice, int]) -> None:
Expand Down Expand Up @@ -185,7 +187,7 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())

def __iadd__(self: T, modules: Iterable[Module]) -> T:
def __iadd__(self, modules: Iterable[Module]) -> 'ModuleList':
return self.extend(modules)

@_copy_to_script_wrapper
Expand All @@ -205,7 +207,7 @@ def insert(self, index: int, module: Module) -> None:
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module

def append(self: T, module: Module) -> T:
def append(self, module: Module) -> 'ModuleList':
r"""Appends a given module to the end of the list.
Args:
Expand All @@ -214,7 +216,7 @@ def append(self: T, module: Module) -> T:
self.add_module(str(len(self)), module)
return self

def extend(self: T, modules: Iterable[Module]) -> T:
def extend(self, modules: Iterable[Module]) -> 'ModuleList':
r"""Appends modules from a Python iterable to the end of the list.
Args:
Expand Down Expand Up @@ -357,6 +359,7 @@ def update(self, modules: Mapping[str, Module]) -> None:
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError("ModuleDict update sequence element "
Expand All @@ -366,7 +369,9 @@ def update(self, modules: Mapping[str, Module]) -> None:
raise ValueError("ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) +
"; 2 is required")
self[m[0]] = m[1]
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]

def forward(self):
raise NotImplementedError()
Expand Down Expand Up @@ -447,15 +452,15 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator['Parameter']:
return iter(self._parameters.values())

def __iadd__(self: T, parameters: Iterable['Parameter']) -> T:
def __iadd__(self, parameters: Iterable['Parameter']) -> 'ParameterList':
return self.extend(parameters)

def __dir__(self):
keys = super(ParameterList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys

def append(self: T, parameter: 'Parameter') -> T:
def append(self, parameter: 'Parameter') -> 'ParameterList':
"""Appends a given parameter at the end of the list.
Args:
Expand All @@ -464,7 +469,7 @@ def append(self: T, parameter: 'Parameter') -> T:
self.register_parameter(str(len(self)), parameter)
return self

def extend(self: T, parameters: Iterable['Parameter']) -> T:
def extend(self, parameters: Iterable['Parameter']) -> 'ParameterList':
"""Appends parameters from a Python iterable to the end of the list.
Args:
Expand Down Expand Up @@ -637,7 +642,8 @@ def update(self, parameters: Mapping[str, 'Parameter']) -> None:
raise ValueError("ParameterDict update sequence element "
"#" + str(j) + " has length " + str(len(p)) +
"; 2 is required")
self[p[0]] = p[1]
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
self[p[0]] = p[1] # type: ignore[assignment]

def extra_repr(self) -> str:
child_lines = []
Expand Down

0 comments on commit a9e46f1

Please sign in to comment.