Force new instance creation in MultiRNNCell (See also CL 145094809)
nealwu committed Jan 20, 2017
1 parent a689e12 commit 520b557
Showing 2 changed files with 14 additions and 9 deletions.
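
A minimal, framework-free sketch of why the change is needed (the CellStub class and make_cell helper are hypothetical stand-ins, not code from this repository): a list built with * repeats a single Python object, so every layer of a MultiRNNCell constructed that way is backed by the same cell instance, whereas a factory called inside a list comprehension yields one independent cell per layer.

class CellStub(object):
  """Stand-in for an RNN cell; a real cell instance owns its own variables."""
  pass

def make_cell():
  return CellStub()

shared = [CellStub()] * 3                 # the same object, three times
fresh = [make_cell() for _ in range(3)]   # three distinct objects

assert len(set(id(c) for c in shared)) == 1  # one instance shared by all layers
assert len(set(id(c) for c in fresh)) == 3   # each layer gets its own instance

The diffs below apply exactly this fix in the two tutorials.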
13 changes: 8 additions & 5 deletions tutorials/rnn/ptb/ptb_word_lm.py
@@ -108,13 +108,16 @@ def __init__(self, is_training, config, input_):
     # Slightly better results can be obtained with forget gate biases
     # initialized to 1 but the hyperparameters of the model would need to be
     # different than reported in the paper.
-    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
-        size, forget_bias=0.0, state_is_tuple=True)
+    def lstm_cell():
+      return tf.contrib.rnn.BasicLSTMCell(
+          size, forget_bias=0.0, state_is_tuple=True)
+    attn_cell = lstm_cell
     if is_training and config.keep_prob < 1:
-      lstm_cell = tf.contrib.rnn.DropoutWrapper(
-          lstm_cell, output_keep_prob=config.keep_prob)
+      def attn_cell():
+        return tf.contrib.rnn.DropoutWrapper(
+            lstm_cell(), output_keep_prob=config.keep_prob)
     cell = tf.contrib.rnn.MultiRNNCell(
-        [lstm_cell] * config.num_layers, state_is_tuple=True)
+        [attn_cell() for _ in range(config.num_layers)], state_is_tuple=True)
 
     self._initial_state = cell.zero_state(batch_size, data_type())
 
10 changes: 6 additions & 4 deletions tutorials/rnn/translate/seq2seq_model.py
@@ -114,12 +114,14 @@ def sampled_loss(labels, inputs):
       softmax_loss_function = sampled_loss
 
     # Create the internal multi-layer cell for our RNN.
-    single_cell = tf.nn.rnn_cell.GRUCell(size)
+    def single_cell():
+      return tf.nn.rnn_cell.GRUCell(size)
     if use_lstm:
-      single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
-    cell = single_cell
+      def single_cell():
+        return tf.nn.rnn_cell.BasicLSTMCell(size)
+    cell = single_cell()
     if num_layers > 1:
-      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
+      cell = tf.nn.rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)])
 
     # The seq2seq function: we use embedding for the input and attention.
     def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
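
The same factory pattern, condensed into one standalone sketch. This assumes a TensorFlow 1.x-era install where tf.nn.rnn_cell.GRUCell, BasicLSTMCell, and MultiRNNCell exist as used in the diff above; the size, num_layers, and use_lstm values are made up for illustration, and the GRU/LSTM choice is folded into a single factory rather than redefining it as the file does.

import tensorflow as tf

size, num_layers, use_lstm = 128, 2, False

def single_cell():
  # Called once per layer, so each layer receives its own cell and variables.
  if use_lstm:
    return tf.nn.rnn_cell.BasicLSTMCell(size)
  return tf.nn.rnn_cell.GRUCell(size)

cell = single_cell()
if num_layers > 1:
  cell = tf.nn.rnn_cell.MultiRNNCell([single_cell() for _ in range(num_layers)])

Passing a list of freshly constructed cells is what lets MultiRNNCell give each layer its own parameters; repeating one already-built cell in that list is the pattern this commit removes.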
