Transformer encoder refined (asyml#19)
* Add TransformerEncoder
* Modify the transformer decoder to use the same poswise hparams generation function.
TomNong authored and huzecong committed May 29, 2019
1 parent 92a31cb commit 5b22719
Showing 9 changed files with 594 additions and 85 deletions.
79 changes: 3 additions & 76 deletions texar/modules/decoders/transformer_decoders.py
@@ -29,6 +29,8 @@
from texar.modules.embedders import EmbedderBase
from texar.modules.encoders.multihead_attention import \
Cache, MultiheadAttentionEncoder
from texar.modules.encoders.transformer_encoder import \
default_transformer_poswise_net_hparams
# from texar.utils import beam_search
from texar.utils import get_instance, transformer_attentions as attn
from texar.utils.shapes import mask_sequences
@@ -40,6 +42,7 @@
]



class TransformerDecoderOutput(NamedTuple):
r"""The output of :class:`TransformerDecoder`.
@@ -53,82 +56,6 @@ class TransformerDecoderOutput(NamedTuple):
sample_id: torch.LongTensor


def default_transformer_poswise_net_hparams(input_dim, output_dim=512):
"""Returns default hyperparameters of a
:class:`~texar.modules.FeedForwardNetwork` as a pos-wise network used
in :class:`~texar.modules.TransformerEncoder` and
:class:`~texar.modules.TransformerDecoder`.
This is a 2-layer feed-forward network with a ReLU activation and dropout
in-between:
.. code-block:: python
{
"layers": [
{
"type": "Linear",
"kwargs": {
"in_features": input_dim,
"out_features": output_dim * 4,
"bias": True,
}
},
{
"type": "ReLU",
"kwargs": {
"inplace": True,
}
},
{
"type": "Dropout",
"kwargs": {
"p": 0.1,
}
},
{
"type": "Linear",
"kwargs": {
"in_features": output_dim * 4,
"out_features": output_dim,
"bias": True,
}
}
],
"name": "ffn"
}
Args:
input_dim (int): The input feature size of the first dense layer.
output_dim (int): The size of the output dense layer.
"""
return {
"layers": [
{
"type": "Linear",
"kwargs": {
"in_features": input_dim,
"out_features": output_dim * 4,
"bias": True,
}
},
{
"type": "ReLU",
"kwargs": {
"inplace": True,
}
},
{
"type": "Dropout",
"kwargs": {
"p": 0.1,
}
},
{
"type": "Linear",
"kwargs": {
"in_features": output_dim * 4,
"out_features": output_dim,
"bias": True,
}
}
],
"name": "ffn"
}


class TransformerDecoder(DecoderBase[Cache, TransformerDecoderOutput]):
r"""Transformer decoder that applies multi-head self-attention for
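The function removed above (and moved into transformer_encoder.py) only builds a hyperparameter spec; the actual network is assembled elsewhere by Texar's FeedForwardNetwork. As a minimal sketch of what that spec describes, assuming each "type"/"kwargs" entry maps onto the torch.nn layer of the same name (the helper name poswise_ffn is hypothetical, not part of Texar):

import torch
from torch import nn


def poswise_ffn(input_dim: int, output_dim: int = 512, dropout: float = 0.1) -> nn.Sequential:
    # Linear -> ReLU -> Dropout -> Linear, mirroring the hparams dict above.
    return nn.Sequential(
        nn.Linear(input_dim, output_dim * 4, bias=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=dropout),
        nn.Linear(output_dim * 4, output_dim, bias=True),
    )


# The network acts on the last dimension only, so it can be applied to a
# [batch, seq_len, dim] tensor position by position.
x = torch.randn(8, 16, 512)
assert poswise_ffn(input_dim=512)(x).shape == (8, 16, 512)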
5 changes: 5 additions & 0 deletions texar/modules/encoders/__init__.py
@@ -12,5 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=empty-docstring
"""
"""

from texar.modules.encoders.encoder_base import *
from texar.modules.encoders.transformer_encoder import *
from texar.modules.encoders.conv_encoders import *
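With the wildcard re-exports above, the new encoder should become importable directly from the package, assuming transformer_encoder.py exposes the names below via its __all__ (the decoder-side import earlier in this commit suggests at least the hparams helper exists there):

from texar.modules.encoders import TransformerEncoder
from texar.modules.encoders.transformer_encoder import \
    default_transformer_poswise_net_hparams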
2 changes: 0 additions & 2 deletions texar/modules/encoders/conv_encoders.py
@@ -15,7 +15,6 @@
Various convolutional network encoders.
"""

import torch

from typing import Dict, Optional, Any

@@ -53,4 +52,3 @@ def default_hparams() -> Dict[str, Any]:
hparams = Conv1DNetwork.default_hparams()
hparams['name'] = 'conv_encoder'
return hparams

2 changes: 1 addition & 1 deletion texar/modules/encoders/encoder_base.py
@@ -42,6 +42,7 @@ def default_hparams() -> Dict[str, Any]:
'name': 'encoder'
}

# pylint: disable=arguments-differ
def forward(self, # type: ignore
inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
r"""Encodes the inputs.
@@ -50,7 +51,6 @@ def forward(self,  # type: ignore
inputs: Inputs to the encoder.
*args: Other arguments.
**kwargs: Keyword arguments.
Returns:
Encoding results.
"""
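The change to encoder_base.py is only cosmetic, but the surrounding context shows the contract a concrete encoder has to satisfy: default_hparams() returns a dictionary containing at least a 'name' entry, and forward(inputs, *args, **kwargs) returns the encoding. A minimal sketch of a subclass, assuming EncoderBase follows the usual Texar convention of an hparams-only constructor (not visible in this diff; the example class is hypothetical):

from typing import Any, Dict

import torch

from texar.modules.encoders.encoder_base import EncoderBase


class MeanPoolingEncoder(EncoderBase):
    # Toy encoder that mean-pools the input over the time dimension.

    def __init__(self, hparams=None):
        super().__init__(hparams=hparams)  # assumption: hparams-only constructor

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        return {'name': 'mean_pooling_encoder'}

    def forward(self,  # type: ignore
                inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # inputs: [batch, seq_len, dim] -> [batch, dim]
        return inputs.mean(dim=1)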
20 changes: 14 additions & 6 deletions texar/modules/encoders/multihead_attention.py
@@ -22,6 +22,7 @@
import torch
import torch.nn.functional as F
from mypy_extensions import TypedDict
# pylint: disable=ungrouped-imports
from torch import nn

from texar import HParams
@@ -34,13 +35,17 @@
'Cache',
]


# pylint: disable=empty-docstring
class LayerCache(TypedDict):
"""
"""
keys: MaybeList[torch.Tensor]
values: MaybeList[torch.Tensor]


class Cache(TypedDict):
"""
"""
memory: Optional[torch.Tensor]
memory_attention_bias: Optional[torch.Tensor]
layers: List[LayerCache]
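Since LayerCache and Cache are plain TypedDicts, a decoding cache is just a nested dictionary with the fields declared above. A hedged illustration of its shape (the values are hypothetical; how Texar actually initializes and updates the cache is not visible in this diff):

import torch

num_layers = 6
encoder_output = torch.randn(2, 10, 512)  # hypothetical encoder memory

cache = {
    'memory': encoder_output,        # Optional[torch.Tensor]
    'memory_attention_bias': None,   # Optional[torch.Tensor]
    'layers': [                      # one LayerCache per decoder layer
        {'keys': [], 'values': []}   # MaybeList[torch.Tensor], empty before decoding
        for _ in range(num_layers)
    ],
}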
@@ -73,6 +78,7 @@ def __init__(self, input_size: int, hparams: Optional[HParams] = None):
self._hparams.output_dim, bias=use_bias)

if self._hparams.initializer:
# pylint: disable=fixme
# TODO: This might be different to what TensorFlow does
initialize = layers.get_initializer(self._hparams.initializer)
assert initialize is not None
@@ -84,7 +90,6 @@ def default_hparams():
r"""Returns a dictionary of hyperparameters with default values.
.. code-block:: python
{
"initializer": None,
'num_heads': 8,
@@ -131,6 +136,7 @@ def default_hparams():
'name': 'multihead_attention',
}

# pylint: disable=arguments-differ, too-many-locals
def forward(self, # type: ignore
queries: torch.Tensor,
memory: torch.Tensor,
@@ -224,9 +230,8 @@ def _update_and_return(layer: nn.Module, key: str):

def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
r"""Split channels (dimension 2) into multiple heads,
becomes dimension 1).
Must ensure `x.shape[-1]` can be divided by num_heads
becomes dimension 1). Must ensure `x.shape[-1]` can be
divided by num_heads.
"""
depth = x.size(-1)
split_x = torch.reshape(x, (
@@ -236,9 +241,9 @@ def _split_heads(self, x: torch.Tensor) -> torch.Tensor:

def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
r"""
Args:
x: A Tensor of shape `[batch, num_heads, seq_len, dim]`
Returns:
A Tensor of shape `[batch, seq_len, num_heads * dim]`
"""
@@ -249,4 +254,7 @@ def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:

@property
def output_size(self):
r"""Provides output dimension as property.
"""

return self._hparams.output_dim
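The reworded docstrings above describe the standard reshape-and-transpose trick used by multi-head attention. A self-contained sketch of what _split_heads and _combine_heads compute (plain PyTorch functions, not the Texar methods themselves, which take num_heads from the module's hparams):

import torch


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # [batch, seq_len, dim] -> [batch, num_heads, seq_len, dim // num_heads]
    batch, seq_len, dim = x.size()
    assert dim % num_heads == 0, "x.shape[-1] must be divisible by num_heads"
    return x.view(batch, seq_len, num_heads, dim // num_heads).permute(0, 2, 1, 3)


def combine_heads(x: torch.Tensor) -> torch.Tensor:
    # [batch, num_heads, seq_len, dim] -> [batch, seq_len, num_heads * dim]
    batch, num_heads, seq_len, dim = x.size()
    return x.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, num_heads * dim)


# Round trip: splitting and then combining recovers the original tensor.
t = torch.randn(2, 5, 8)
assert torch.equal(combine_heads(split_heads(t, num_heads=4)), t)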