Pass kwargs to wrapped cell in AttentionWrapper (tensorflow#272)
guillaumekln authored and seanpmorgan committed Jun 8, 2019
1 parent 84bb63f commit c804ca8
Showing 1 changed file with 4 additions and 2 deletions:

tensorflow_addons/seq2seq/attention_wrapper.py
@@ -1857,7 +1857,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
                 _alignment_history else ()
             for alignment in initial_alignments))
 
-    def call(self, inputs, state):
+    def call(self, inputs, state, **kwargs):
         """Perform a step of attention-wrapped RNN.
 
         - Step 1: Mix the `inputs` and previous step's `attention` output via
@@ -1878,6 +1878,7 @@ def call(self, inputs, state):
             step.
           state: An instance of `AttentionWrapperState` containing
             tensors from the previous time step.
+          **kwargs: Dict, other keyword arguments for the cell call method.
 
         Returns:
           A tuple `(attention_or_cell_output, next_state)`, where:
@@ -1898,7 +1899,8 @@ def call(self, inputs, state):
         # previous attention value.
         cell_inputs = self._cell_input_fn(inputs, state.attention)
         cell_state = state.cell_state
-        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
+        cell_output, next_cell_state = self._cell(
+            cell_inputs, cell_state, **kwargs)
 
         cell_batch_size = (tf.compat.dimension_value(cell_output.shape[0])
                            or tf.shape(cell_output)[0])
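
The practical effect of the change above is that Keras-style keyword arguments such as `training` can now flow through `AttentionWrapper.call` to the wrapped cell. Below is a minimal sketch of that usage, assuming the tfa.seq2seq API of this release; the layer choices, sizes, and the `training=True` flag are illustrative, not part of the commit:

import tensorflow as tf
import tensorflow_addons as tfa

# Illustrative sizes.
batch_size, max_time, units = 4, 7, 16
memory = tf.random.normal([batch_size, max_time, units])

# A recurrent cell whose behavior depends on `training` (dropout is only
# active during training), wrapped with attention.
attention_mechanism = tfa.seq2seq.LuongAttention(units, memory=memory)
cell = tf.keras.layers.LSTMCell(units, dropout=0.5)
wrapper = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism)

inputs = tf.random.normal([batch_size, units])
state = wrapper.get_initial_state(batch_size=batch_size, dtype=tf.float32)

# With this commit, `training=True` is forwarded to the inner LSTMCell via
# **kwargs; previously it stopped at the wrapper.
output, next_state = wrapper(inputs, state, training=True)

Forwarding **kwargs rather than naming `training` explicitly keeps the wrapper agnostic to whatever keyword arguments the inner cell's call method accepts.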
