Skip to content

Commit

Permalink
Bug fix + test : Initializing states for ConvLSTM2D (keras-team#6564)
Browse files Browse the repository at this point in the history
* Bug fix

* Update convolutional_recurrent_test.py

* Update convolutional_recurrent.py
  • Loading branch information
farizrahman4u authored and fchollet committed May 10, 2017
1 parent 672028a commit 2766074
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
12 changes: 7 additions & 5 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def __init__(self, filters,
self.return_sequences = return_sequences
self.go_backwards = go_backwards
self.stateful = stateful
self.input_spec = InputSpec(ndim=5)
self.input_spec = [InputSpec(ndim=5)]
self.state_spec = None

def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
Expand Down Expand Up @@ -330,9 +331,10 @@ def __init__(self, filters,
self.recurrent_dropout = min(1., max(0., recurrent_dropout))

def build(self, input_shape):
# TODO: better handling of input spec
self.input_spec = InputSpec(shape=input_shape)

if isinstance(input_shape, list):
input_shape = input_shape[0]
batch_size = input_shape[0] if self.stateful else None
self.input_spec[0] = InputSpec(shape=(batch_size,) + input_shape[1:])
if self.stateful:
self.reset_states()
else:
Expand Down Expand Up @@ -413,7 +415,7 @@ def get_initial_state(self, inputs):
def reset_states(self):
if not self.stateful:
raise RuntimeError('Layer must be stateful.')
input_shape = self.input_spec.shape
input_shape = self.input_spec[0].shape
output_shape = self.compute_output_shape(input_shape)
if not input_shape[0]:
raise ValueError('If a RNN is stateful, a complete '
Expand Down
17 changes: 15 additions & 2 deletions tests/keras/layers/convolutional_recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from numpy.testing import assert_allclose

from keras import backend as K
from keras.models import Sequential
from keras.layers import convolutional_recurrent
from keras.models import Sequential, Model
from keras.layers import convolutional_recurrent, Input
from keras.utils.test_utils import layer_test
from keras import regularizers

Expand Down Expand Up @@ -116,5 +116,18 @@ def test_convolutional_recurrent():
'recurrent_dropout': 0.1},
input_shape=inputs.shape)

# check state initialization
layer = convolutional_recurrent.ConvLSTM2D(filters=filters,
kernel_size=(num_row, num_col),
data_format=data_format,
return_sequences=return_sequences)
layer.build(inputs.shape)
x = Input(batch_shape=inputs.shape)
initial_state = layer.get_initial_state(x)
y = layer(x, initial_state=initial_state)
model = Model(x, y)
assert model.predict(inputs).shape == layer.compute_output_shape(inputs.shape)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 2766074

Please sign in to comment.