From c804ca83ad576b530ca8e588d755704d23afe30a Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Sat, 8 Jun 2019 19:09:57 +0200
Subject: [PATCH] Pass kwargs to wrapped cell in AttentionWrapper (#272)

---
 tensorflow_addons/seq2seq/attention_wrapper.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index 1f338b3386..3abe62bd8c 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -1857,7 +1857,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
                     _alignment_history else ()
                     for alignment in initial_alignments))
 
-    def call(self, inputs, state):
+    def call(self, inputs, state, **kwargs):
         """Perform a step of attention-wrapped RNN.
 
         - Step 1: Mix the `inputs` and previous step's `attention` output via
@@ -1878,6 +1878,7 @@ def call(self, inputs, state):
             step.
           state: An instance of `AttentionWrapperState` containing
             tensors from the previous time step.
+          **kwargs: Dict, other keyword arguments for the cell call method.
 
         Returns:
           A tuple `(attention_or_cell_output, next_state)`, where:
@@ -1898,7 +1899,8 @@ def call(self, inputs, state):
         # previous attention value.
         cell_inputs = self._cell_input_fn(inputs, state.attention)
         cell_state = state.cell_state
-        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
+        cell_output, next_cell_state = self._cell(
+            cell_inputs, cell_state, **kwargs)
 
         cell_batch_size = (tf.compat.dimension_value(cell_output.shape[0]) or
                            tf.shape(cell_output)[0])
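
Note (not part of the patch): a minimal usage sketch of the behavior this change enables. Keyword arguments passed to the AttentionWrapper call are now forwarded to the wrapped cell, so arguments such as `training` reach the inner cell instead of being dropped. The mechanism, cell, and hyperparameters below are illustrative choices, not taken from the patch.

    import tensorflow as tf
    import tensorflow_addons as tfa

    batch_size, max_time, units = 4, 7, 16
    memory = tf.random.normal([batch_size, max_time, units])

    # Wrap a dropout-enabled LSTMCell with an (arbitrary) attention mechanism.
    attention_mechanism = tfa.seq2seq.LuongAttention(units, memory)
    cell = tfa.seq2seq.AttentionWrapper(
        tf.keras.layers.LSTMCell(units, dropout=0.2), attention_mechanism)

    state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    inputs = tf.random.normal([batch_size, units])

    # With this patch, `training=True` is forwarded to the wrapped LSTMCell,
    # so its dropout is active for this step; previously the keyword argument
    # never reached the inner cell.
    output, next_state = cell(inputs, state, training=True)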