Remove unnecessary arguments from decoder helper interface
huzecong committed May 21, 2019
1 parent d324f75 commit 48a00b9
Showing 3 changed files with 57 additions and 61 deletions.
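In short, this commit drops the decoder cell state from the helper interface: Helper.sample no longer takes a state argument, and Helper.next_inputs neither takes state nor returns next_state; the decoder now tracks its own recurrent state. A minimal sketch of the simplified interface, reconstructed from the diff below (abbreviated, not verbatim source):

import torch
from typing import Tuple

# (finished, next_inputs) -- the helper no longer returns next_state
NextInputTuple = Tuple[torch.ByteTensor, torch.Tensor]

class Helper:
    def sample(self, time: int, outputs: torch.Tensor):
        # Previously also received the decoder state; now only this step's logits.
        raise NotImplementedError

    def next_inputs(self, time: int, outputs: torch.Tensor,
                    sample_ids) -> NextInputTuple:
        # Previously returned (finished, next_inputs, next_state).
        raise NotImplementedError
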
104 changes: 52 additions & 52 deletions texar/modules/decoders/decoder_helpers.py
@@ -19,20 +19,20 @@
# pylint: disable=missing-docstring # does not support generic classes

from abc import ABC
from typing import Generic, Optional, Tuple, TypeVar, Union, Type
from typing import Generic, Optional, Tuple, Type, TypeVar, Union

import torch
from torch import nn
from torch.distributions import Categorical, Gumbel

from texar.core.cell_wrappers import HiddenState
from texar.modules.embedders import EmbedderBase
from texar.modules.embedders import EmbedderBase, WordEmbedder
from texar.utils import utils


__all__ = [
'_convert_embedding',
'Helper',
'TrainingHelper',
'EmbeddingHelper',
'GreedyEmbeddingHelper',
'SampleEmbeddingHelper',
'TopKSampleEmbeddingHelper',
@@ -46,34 +46,36 @@
# `ScheduledOutputTrainingHelper`

HelperInitTuple = Tuple[torch.ByteTensor, torch.Tensor]
NextInputTuple = Tuple[torch.ByteTensor, torch.Tensor, HiddenState]
NextInputTuple = Tuple[torch.ByteTensor, torch.Tensor]

Embedding = Union[nn.Embedding, torch.Tensor, EmbedderBase]
Embedding = Union[
nn.Embedding, # PyTorch built-in embedding module
EmbedderBase, # Texar embedder base class
]


def _get_raw_embedding(embedding: Embedding) -> nn.Embedding:
r"""Convert raw tensors and Embedders into nn.Embedding.
TODO: After Embedders are implemented, refactor this to convert everything
to Embedders.
def _convert_embedding(embedding: Union[torch.Tensor, Embedding]) -> Embedding:
r"""Wrap raw tensors into Embedder instances. If the input is already an
Embedder instance, or an instance of :class:`~torch.nn.Embedding`, it is
returned as is.
Args:
embedding (nn.Embedding or torch.Tensor or EmbedderBase): the embedding
to convert.
embedding (nn.Embedding or torch.Tensor or EmbedderBase or callable):
the embedding to convert.
Returns:
An instance of nn.Embedding.
An instance of Embedder or nn.Embedding.
"""
if not (callable(embedding) or
torch.is_tensor(embedding) or
isinstance(embedding, EmbedderBase)):
raise ValueError(
"'embedding' must either be a torch.Tensor, a callable, or an "
"Embedder instance.")
if isinstance(embedding, EmbedderBase):
embedding = embedding.embedding
if torch.is_tensor(embedding):
return nn.Embedding(embedding.size(0), embedding.size(1),
_weight=embedding)
embedding = WordEmbedder(init_value=embedding)
elif isinstance(embedding, nn.Embedding):
pass
elif isinstance(embedding, EmbedderBase):
pass
else:
raise ValueError(
"'embedding' must either be a torch.Tensor, an Embedder instance, "
"or an nn.Embedding instance.")
return embedding


@@ -107,12 +109,11 @@ def initialize(self, inputs: Optional[torch.Tensor],
"""
raise NotImplementedError

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> IDType:
def sample(self, time: int, outputs: torch.Tensor) -> IDType:
r"""Returns `sample_ids`."""
raise NotImplementedError

def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,
def next_inputs(self, time: int, outputs: torch.Tensor,
sample_ids: IDType) -> NextInputTuple:
r"""Returns `(finished, next_inputs, next_state)`."""
raise NotImplementedError
@@ -123,7 +124,7 @@ class TrainingHelper(Helper[torch.LongTensor]):
Returned sample_ids are the argmax of the RNN output logits.
"""
_embedding: Optional[nn.Embedding]
_embedding: Optional[Embedding]

# the following are set in `initialize`
_inputs: torch.Tensor
@@ -148,7 +149,7 @@ def __init__(self, embedding: Optional[Embedding] = None,
ValueError: if `sequence_length` is not a 1D tensor.
"""
if embedding is not None:
self._embedding = _get_raw_embedding(embedding)
self._embedding = _convert_embedding(embedding)
else:
self._embedding = None

@@ -188,13 +189,12 @@ def initialize(self, inputs: Optional[torch.Tensor],
next_inputs = inputs[0] if not all_finished else self._zero_inputs
return (finished, next_inputs)

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.LongTensor:
# pylint: disable=unused-variable, no-self-use
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
del time
sample_ids = torch.argmax(outputs, dim=-1)
return sample_ids

def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,
def next_inputs(self, time: int, outputs: torch.Tensor,
sample_ids: torch.LongTensor) -> NextInputTuple:
r"""next_inputs_fn for TrainingHelper."""
next_time = time + 1
@@ -203,7 +203,7 @@ def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,

next_inputs = (self._inputs[next_time] if not all_finished
else self._zero_inputs)
return (finished, next_inputs, state)
return (finished, next_inputs)


class EmbeddingHelper(Helper[IDType], ABC):
@@ -233,7 +233,7 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
if not isinstance(end_token, int) and end_token.dim() != 0:
raise ValueError("end_token must be a scalar")

self._embedding = _get_raw_embedding(embedding)
self._embedding = _convert_embedding(embedding)

self._start_tokens = start_tokens
self._batch_size = start_tokens.size(0)
@@ -243,6 +243,10 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
else:
self._end_token = end_token

@property
def batch_size(self) -> int:
return self._batch_size

def initialize(self, inputs: Optional[torch.Tensor],
sequence_length: Optional[torch.LongTensor]) \
-> HelperInitTuple:
@@ -263,7 +267,7 @@ class SingleEmbeddingHelper(EmbeddingHelper[torch.LongTensor], ABC):
def sample_ids_shape(self) -> torch.Size:
return torch.Size()

def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,
def next_inputs(self, time: int, outputs: torch.Tensor,
sample_ids: torch.LongTensor) -> NextInputTuple:
r"""next_inputs_fn for GreedyEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
@@ -272,7 +276,7 @@ def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,

next_inputs = (self._embedding(sample_ids) if not all_finished
else self._zero_inputs)
return (finished, next_inputs, state)
return (finished, next_inputs)


class GreedyEmbeddingHelper(SingleEmbeddingHelper):
@@ -299,10 +303,9 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
"""
super().__init__(embedding, start_tokens, end_token)

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.LongTensor:
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
r"""sample for GreedyEmbeddingHelper."""
del time, state # unused by sample_fn
del time # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not torch.is_tensor(outputs):
raise TypeError(
@@ -342,10 +345,9 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
super().__init__(embedding, start_tokens, end_token)
self._softmax_temperature = softmax_temperature

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.LongTensor:
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
r"""sample for SampleEmbeddingHelper."""
del time, state # unused by sample_fn
del time # unused by sample_fn
# Outputs are logits, we sample instead of argmax (greedy).
if not torch.is_tensor(outputs):
raise TypeError(
@@ -417,10 +419,9 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
self._top_k = top_k
self._softmax_temperature = softmax_temperature

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.LongTensor:
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
r"""sample for SampleEmbeddingHelper."""
del time, state # unused by sample_fn
del time # unused by sample_fn
# Outputs are logits, we sample from the top-k candidates
if not torch.is_tensor(outputs):
raise TypeError(
@@ -479,16 +480,16 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
def sample_ids_shape(self) -> torch.Size:
return torch.Size([self._embedding.num_embeddings])

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.Tensor:
def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
r"""Returns `sample_id` which is softmax distributions over vocabulary
with temperature `tau`. Shape = `[batch_size, vocab_size]`
"""
del time
sample_ids = torch.softmax(outputs / self._tau, dim=-1)
return sample_ids

def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,
sample_ids: torch.Tensor) -> NextInputTuple:
def next_inputs(self, time: int, outputs: torch.Tensor,
sample_ids: torch.LongTensor) -> NextInputTuple:
r"""next_inputs_fn for SoftmaxEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
if self._use_finish:
@@ -500,7 +501,7 @@ def next_inputs(self, time: int, outputs: torch.Tensor, state: HiddenState,
if self._stop_gradient:
sample_ids = sample_ids.detach()
next_inputs = torch.matmul(sample_ids, self._embedding.weight)
return (finished, next_inputs, state)
return (finished, next_inputs)


class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
@@ -551,8 +552,7 @@ def __init__(self, embedding: Embedding, start_tokens: torch.LongTensor,
self._gumbel = Gumbel(loc=torch.tensor(0.0),
scale=torch.tensor(1.0))

def sample(self, time: int, outputs: torch.Tensor,
state: HiddenState) -> torch.Tensor:
def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
r"""Returns `sample_id` of shape `[batch_size, vocab_size]`. If
`straight_through` is False, this is gumbel softmax distributions over
vocabulary with temperature `tau`. If `straight_through` is True,
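For reference, a hedged usage sketch of the updated _convert_embedding from the file above: a raw weight tensor is now wrapped in a Texar WordEmbedder (instead of a plain nn.Embedding), while nn.Embedding and EmbedderBase instances pass through unchanged. The vocabulary size, dimension, and ids below are illustrative only, assuming the import path matches the file shown:

import torch
from torch import nn
from texar.modules.decoders.decoder_helpers import _convert_embedding

weight = torch.randn(100, 64)                     # illustrative [vocab_size, dim] weights
emb = _convert_embedding(weight)                  # wrapped into WordEmbedder(init_value=weight)
emb = _convert_embedding(nn.Embedding(100, 64))   # nn.Embedding is returned as-is
vectors = emb(torch.tensor([3, 7, 42]))           # embedders map id tensors to embeddings
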
2 changes: 1 addition & 1 deletion texar/modules/decoders/rnn_decoder_base.py
@@ -231,7 +231,7 @@ def initialize(self, helper: Helper, inputs: Optional[torch.Tensor],
-> Tuple[torch.ByteTensor, torch.Tensor, HiddenState]:
initial_finished, initial_inputs = helper.initialize(
inputs, sequence_length)
state = initial_state or self._cell.init_batch(initial_inputs.size(0))
state = initial_state or self._cell.init_batch()
return (initial_finished, initial_inputs, state)

def step(self, helper: Helper, time: int,
12 changes: 4 additions & 8 deletions texar/modules/decoders/rnn_decoders.py
@@ -232,21 +232,17 @@ def step(self, helper: Helper, time: int,
torch.Tensor, torch.ByteTensor]:
cell_outputs, cell_state = self._cell(inputs, state)
logits = self._output_layer(cell_outputs)
sample_ids = helper.sample(
time=time, outputs=logits, state=cell_state)
(finished, next_inputs, next_state) = helper.next_inputs(
sample_ids = helper.sample(time=time, outputs=logits)
(finished, next_inputs) = helper.next_inputs(
time=time,
outputs=logits,
state=cell_state,
sample_ids=sample_ids)
next_state = cell_state
outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
return (outputs, next_state, next_inputs, finished)

@property
def output_size(self):
r"""Output size of one step.
"""
return BasicRNNDecoderOutput(
logits=self._rnn_output_size(),
sample_id=self._helper.sample_ids_shape,
cell_output=self._cell.output_size)
return self._cell.hidden_size
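
Taken together, one decoding step under the new interface looks roughly like the sketch below, a paraphrase of the step hunk above rather than verbatim source; the cell state stays on the decoder side, and the helper only sees logits and sample ids:

cell_outputs, cell_state = self._cell(inputs, state)
logits = self._output_layer(cell_outputs)
sample_ids = helper.sample(time=time, outputs=logits)      # no state argument
finished, next_inputs = helper.next_inputs(
    time=time, outputs=logits, sample_ids=sample_ids)      # no next_state returned
next_state = cell_state                                    # decoder keeps its own state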
