Skip to content

Commit

Permalink
Bug fix of Bidirectional(LSTM(..., stateful=True)) (keras-team#4424)
Browse files Browse the repository at this point in the history
* Bug fix of Bidirectional(LSTM(..., stateful=True)) keras-team#4421

* Add Recurrent.from_config() test
  • Loading branch information
yukoba authored and fchollet committed Nov 18, 2016
1 parent 8653060 commit 04ea01f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def get_config(self):
'stateful': self.stateful,
'unroll': self.unroll,
'consume_less': self.consume_less}
if self.stateful:
if self.stateful and self.input_spec[0].shape:
config['batch_input_shape'] = self.input_spec[0].shape
else:
config['input_dim'] = self.input_dim
Expand Down
8 changes: 8 additions & 0 deletions tests/keras/layers/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,13 @@ def test_masking_layer():
model.fit(I, V, nb_epoch=1, batch_size=100, verbose=1)


@rnn_test
def test_from_config(layer_class):
    """Round-trip a recurrent layer through get_config()/from_config().

    Covers both the stateless and stateful cases; the configs of the
    original and the reconstructed layer must match exactly.
    """
    for stateful in [False, True]:
        original = layer_class(output_dim=1, stateful=stateful)
        restored = layer_class.from_config(original.get_config())
        assert original.get_config() == restored.get_config()


# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    pytest.main([__file__])
7 changes: 7 additions & 0 deletions tests/keras/layers/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def test_Bidirectional():
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, nb_epoch=1, batch_size=1)

# Bidirectional and stateful
input = Input(batch_shape=(1, timesteps, dim))
output = wrappers.Bidirectional(rnn(output_dim, stateful=True), merge_mode=mode)(input)
model = Model(input, output)
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, nb_epoch=1, batch_size=1)


# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit 04ea01f

Please sign in to comment.