Type annotations almost everywhere
huzecong committed Mar 29, 2019 · 1 parent ebcda15 · commit a5f8f74
Showing 15 changed files with 366 additions and 318 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -217,6 +217,9 @@ docs/_build
 ### pytest ###
 /.pytest_cache/
 
+### mypy ###
+/.mypy_cache/
+
 ### Project ###
 /data/
 checkpoints/
18 changes: 18 additions & 0 deletions mypy.ini
@@ -0,0 +1,18 @@
+# Global options:
+
+[mypy]
+python_version = 3.7
+warn_unused_ignores = True
+warn_unused_configs = True
+no_implicit_optional = True
+python_executable = venv/bin/python
+follow_imports = silent
+ignore_missing_imports = True
+mypy_path = ./,venv/lib/python3.7/site-packages
+allow_redefinition = True
+
+[mypy-torch]
+ignore_errors = True
+
+[mypy-numpy]
+follow_imports = skip
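
A quick illustration of what two of the stricter options above enforce (a hypothetical snippet, not part of this commit): with no_implicit_optional, a parameter that defaults to None needs an explicit Optional annotation, and allow_redefinition lets a name be rebound to a different type within the same block.

    from typing import List, Optional

    def scale(x: float, factor: Optional[float] = None) -> float:
        # Writing `factor: float = None` would be rejected under no_implicit_optional.
        return x if factor is None else x * factor

    def tokenize(text: str) -> List[str]:
        tokens = text.strip()    # inferred as str
        tokens = tokens.split()  # rebound to List[str]; allowed by allow_redefinition
        return tokens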
59 changes: 32 additions & 27 deletions texar/core/cell_wrappers.py
@@ -16,17 +16,19 @@
 """
 
 # pylint: disable=redefined-builtin, arguments-differ, too-many-arguments
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, too-few-public-methods
 
-from typing import Optional, List, Tuple, Union
+from typing import Generic, List, Optional, Tuple, TypeVar, Union
 
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 from texar.utils import utils
+from texar.utils.types import MaybeList
 
 __all__ = [
+    'HiddenState',
     'wrap_builtin_cell',
     'RNNCellBase',
     'RNNCell',
@@ -38,10 +40,11 @@
     'MultiRNNCell',
 ]
 
+State = TypeVar('State')
 RNNState = torch.Tensor
 LSTMState = Tuple[torch.Tensor, torch.Tensor]
-State = Union[RNNState, LSTMState]
-MultiState = Union[State, List[State]]
+
+HiddenState = MaybeList[Union[RNNState, LSTMState]]
 
 
 def wrap_builtin_cell(cell: nn.RNNCellBase):
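
The new State TypeVar above is the core of this commit: the base class becomes generic over its state type, so each subclass can declare exactly what its zero_state and forward exchange, rather than falling back on the old State/MultiState unions. A minimal self-contained sketch of the pattern (class names here are illustrative, not part of the commit):

    from typing import Generic, Tuple, TypeVar

    import torch
    from torch import nn

    State = TypeVar('State')

    class CellBase(nn.Module, Generic[State]):
        def zero_state(self, batch_size: int) -> State:
            raise NotImplementedError

    class PairStateCell(CellBase[Tuple[torch.Tensor, torch.Tensor]]):
        def __init__(self, hidden_size: int):
            super().__init__()
            self.hidden_size = hidden_size

        def zero_state(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
            h = torch.zeros(batch_size, self.hidden_size)
            return (h, h)  # the checker sees a concrete (h, c) pair, not a union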
@@ -64,11 +67,11 @@ def wrap_builtin_cell(cell: nn.RNNCellBase):
         self = RNNCellBase.__new__(LSTMCell)
     else:
         raise TypeError(f"Unrecognized class {type(cell)}.")
-    self._cell = cell  # pylint: disable=protected-access
+    self._cell = cell  # pylint: disable=protected-access, attribute-defined-outside-init
     return self
 
 
-class RNNCellBase(nn.Module):
+class RNNCellBase(nn.Module, Generic[State]):
     r"""The base class for RNN cells in our framework. Major differences over
     :class:`torch.nn.RNNCell` are two-fold::
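
For reference, a short usage sketch of wrap_builtin_cell (shapes are illustrative, and full initialization of the returned wrapper is assumed from the function's intent, not verified here):

    import torch
    from torch import nn

    from texar.core.cell_wrappers import wrap_builtin_cell

    cell = wrap_builtin_cell(nn.LSTMCell(input_size=32, hidden_size=64))
    state = cell.zero_state(batch_size=8)      # an (h, c) pair for LSTM cells
    output, state = cell(torch.randn(8, 32), state)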
@@ -115,7 +118,7 @@ def init_batch(self, batch_size: int):
             batch_size: int, the batch size.
         """
 
-    def zero_state(self, batch_size: int) -> MultiState:
+    def zero_state(self, batch_size: int) -> State:
         r"""Return zero-filled state tensor(s).
 
         Args:
@@ -133,8 +136,8 @@ def zero_state(self, batch_size: int) -> MultiState:
         state = self._cell.zero_state(batch_size)
         return state
 
-    def forward(self, input: torch.Tensor, state: Optional[MultiState] = None) \
-            -> Tuple[torch.Tensor, MultiState]:
+    def forward(self, input: torch.Tensor, state: Optional[State] = None) \
+            -> Tuple[torch.Tensor, State]:
         r"""
         Returns:
             A tuple of (output, state). For single layer RNNs, output is
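
The (output, state) contract above makes manual unrolling over time straightforward; a hedged sketch (tensor shapes are illustrative):

    import torch

    from texar.core.cell_wrappers import GRUCell

    cell = GRUCell(input_size=16, hidden_size=32)
    inputs = torch.randn(10, 8, 16)        # (time, batch, input_size)

    cell.init_batch(batch_size=8)          # a no-op for plain cells; wrappers use it
    state = cell.zero_state(batch_size=8)
    outputs = []
    for t in range(inputs.size(0)):
        output, state = cell(inputs[t], state)
        outputs.append(output)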
@@ -146,7 +149,7 @@ def forward(self, input: torch.Tensor, state: Optional[MultiState] = None) \
         return self._cell(input, state)
 
 
-class BuiltinCellWrapper(RNNCellBase):
+class BuiltinCellWrapper(RNNCellBase[State]):
     r"""Base class for wrappers over built-in :class:`torch.nn.RNNCellBase`
     RNN cells.
     """
@@ -160,7 +163,7 @@ def forward(self, input: torch.Tensor, state: Optional[State] = None) \
         return new_state, new_state
 
 
-class RNNCell(BuiltinCellWrapper):
+class RNNCell(BuiltinCellWrapper[RNNState]):
     r"""A wrapper over :class:`torch.nn.RNNCell`."""
 
     def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
@@ -169,15 +172,15 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
         super().__init__(cell)
 
 
-class GRUCell(BuiltinCellWrapper):
+class GRUCell(BuiltinCellWrapper[RNNState]):
     r"""A wrapper over :class:`torch.nn.GRUCell`."""
 
     def __init__(self, input_size, hidden_size, bias=True):
         cell = nn.GRUCell(input_size, hidden_size, bias=bias)
         super().__init__(cell)
 
 
-class LSTMCell(BuiltinCellWrapper):
+class LSTMCell(BuiltinCellWrapper[LSTMState]):
     r"""A wrapper over :class:`torch.nn.LSTMCell`, additionally providing the
     option to initialize the forget-gate bias to a constant value.
     """
@@ -195,8 +198,10 @@ def __init__(self, input_size, hidden_size, bias=True,
         super().__init__(cell)
 
     def zero_state(self, batch_size: int) -> LSTMState:
-        state = super().zero_state(batch_size)
-        return (state, state)  # (h, c)
+        r"""Returns the zero state for LSTMs as (h, c)."""
+        state = self._param.new_zeros(
+            batch_size, self.hidden_size, requires_grad=False)
+        return (state, state)
 
     def forward(self, input: torch.Tensor, state: Optional[LSTMState] = None) \
             -> Tuple[torch.Tensor, LSTMState]:
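
Because an LSTM state is an (h, c) pair, zero_state above returns the zero tensor twice (safe, since subsequent cell operations do not modify it in place). The plain-PyTorch equivalent of one step, for comparison:

    import torch
    from torch import nn

    cell = nn.LSTMCell(input_size=16, hidden_size=32)
    h0 = torch.zeros(8, 32)   # hidden state
    c0 = torch.zeros(8, 32)   # memory (cell) state
    h1, c1 = cell(torch.randn(8, 16), (h0, c0))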
@@ -207,10 +212,10 @@ def forward(self, input: torch.Tensor, state: Optional[LSTMState] = None) \
         return new_state[0], new_state
 
 
-class DropoutWrapper(RNNCellBase):
+class DropoutWrapper(RNNCellBase[State]):
     r"""Operator adding dropout to inputs and outputs of the given cell."""
 
-    def __init__(self, cell: nn.Module,
+    def __init__(self, cell: RNNCellBase[State],
                  input_keep_prob: float = 1.0,
                  output_keep_prob: float = 1.0,
                  state_keep_prob: float = 1.0,
@@ -258,9 +263,9 @@ def __init__(self, cell: nn.Module,
         self._state_keep_prob = state_keep_prob
 
         self._variational_recurrent = variational_recurrent
-        self._recurrent_input_mask = None
-        self._recurrent_output_mask = None
-        self._recurrent_state_mask = None
+        self._recurrent_input_mask: Optional[torch.Tensor] = None
+        self._recurrent_output_mask: Optional[torch.Tensor] = None
+        self._recurrent_state_mask: Optional[torch.Tensor] = None
 
     def init_batch(self, batch_size: int):
         r"""Initialize dropout masks for variational dropout.
@@ -307,7 +312,7 @@ def forward(self, input: torch.Tensor, state: Optional[State] = None) \
         return output, new_state
 
 
-class ResidualWrapper(RNNCellBase):
+class ResidualWrapper(RNNCellBase[State]):
     r"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
 
     def forward(self, input: torch.Tensor, state: Optional[State] = None) \
@@ -317,7 +322,7 @@ def forward(self, input: torch.Tensor, state: Optional[State] = None) \
         return output, new_state
 
 
-class HighwayWrapper(RNNCellBase):
+class HighwayWrapper(RNNCellBase[State]):
     r"""RNNCell wrapper that adds highway connection on cell input and output.
 
     Based on:
return output, new_state


class MultiRNNCell(RNNCellBase):
class MultiRNNCell(RNNCellBase[List[State]]):
r"""RNN cell composed sequentially of multiple simple cells.
.. code-block:: python
@@ -371,9 +376,9 @@ class MultiRNNCell(RNNCellBase):
         stacked_rnn_cell = MultiRNNCell(cells)
     """
 
-    _cell: List[RNNCellBase]  # for better autocompletion
+    _cell: nn.ModuleList
 
-    def __init__(self, cells: List[nn.Module]):
+    def __init__(self, cells: List[RNNCellBase[State]]):
         r"""Create a RNN cell composed sequentially of a number of RNNCells.
 
         Args:
@@ -399,7 +404,7 @@ def init_batch(self, batch_size: int):
         for cell in self._cell:
             cell.init_batch(batch_size)
 
-    def zero_state(self, batch_size: int):
+    def zero_state(self, batch_size: int) -> List[State]:
         states = [cell.zero_state(batch_size) for cell in self._cell]
         return states
 
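Since MultiRNNCell is parameterized as RNNCellBase[List[State]], zero_state yields one state per layer. A usage sketch following the docstring's example (sizes are illustrative):

    import torch

    from texar.core.cell_wrappers import LSTMCell, MultiRNNCell

    cells = [LSTMCell(16, 32), LSTMCell(32, 32)]
    stacked_rnn_cell = MultiRNNCell(cells)

    states = stacked_rnn_cell.zero_state(batch_size=8)  # one (h, c) per layer
    output, states = stacked_rnn_cell(torch.randn(8, 16), states)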
8 changes: 4 additions & 4 deletions texar/core/cell_wrappers_test.py
@@ -3,20 +3,20 @@
 :func:`~texar.core.layers.get_rnn_cell`.
 """
 
-# pylint: disable=too-many-locals, protected-access, unused-variable
-# pylint: disable=redefined-builtin, invalid-name
-
 import unittest
 
 import torch
 from torch import nn
 
 import texar.core.cell_wrappers as wrappers
 from texar import HParams
-from texar.core.layers import get_rnn_cell, default_rnn_cell_hparams
+from texar.core.layers import default_rnn_cell_hparams, get_rnn_cell
 from texar.utils import utils
 
 
+# pylint: disable=too-many-locals, protected-access, unused-variable
+# pylint: disable=redefined-builtin, invalid-name
+
 class WrappersTest(unittest.TestCase):
     r"""Tests cell wrappers and :func:`~texar.core.layers.get_rnn_cell`.
     """
30 changes: 15 additions & 15 deletions texar/core/layers.py
@@ -15,15 +15,15 @@
 Various neural network layers
 """
 
-# pylint: disable=too-many-branches
-
 import torch
 from torch import nn
 
 import texar.core.cell_wrappers as wrappers
 from texar.hyperparams import HParams
 from texar.utils import utils
 
+# pylint: disable=too-many-branches
+
 __all__ = [
     "default_rnn_cell_hparams",
     "get_rnn_cell",
@@ -116,21 +116,21 @@ def default_rnn_cell_hparams():
         "num_layers" = 1.
     """
     return {
-        "type": "LSTMCell",
-        "input_size": 256,
-        "kwargs": {
-            "hidden_size": 256,
+        'type': 'LSTMCell',
+        'input_size': 256,
+        'kwargs': {
+            'hidden_size': 256,
         },
-        "num_layers": 1,
-        "dropout": {
-            "input_keep_prob": 1.0,
-            "output_keep_prob": 1.0,
-            "state_keep_prob": 1.0,
-            "variational_recurrent": False,
+        'num_layers': 1,
+        'dropout': {
+            'input_keep_prob': 1.0,
+            'output_keep_prob': 1.0,
+            'state_keep_prob': 1.0,
+            'variational_recurrent': False,
         },
-        "residual": False,
-        "highway": False,
-        "@no_typecheck": ["type"]
+        'residual': False,
+        'highway': False,
+        '@no_typecheck': ['type']
     }


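To tie the defaults together, a hedged sketch of building a cell from these hyperparameters; the exact call form of get_rnn_cell is assumed from the test file's imports, since this diff does not show its signature:

    from texar import HParams
    from texar.core.layers import default_rnn_cell_hparams, get_rnn_cell

    hparams = HParams({'type': 'GRUCell', 'num_layers': 2},   # overrides
                      default_rnn_cell_hparams())
    cell = get_rnn_cell(hparams)  # assumed signature; see cell_wrappers_test.py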
(Diff truncated: the remaining 10 changed files are not shown.)
