Skip to content

Commit

Permalink
Bug fix in convolutional recurrent state setting
Browse files Browse the repository at this point in the history
* Bug fix: convolutional recurrent (again)

* pep8

* Update convolutional_recurrent.py

* pep8
  • Loading branch information
farizrahman4u authored and fchollet committed May 14, 2017
1 parent 6220e35 commit 0d27d90
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(self, filters,
self.state_spec = None

def compute_output_shape(self, input_shape):
if type(input_shape) is list:
input_shape = input_shape[0]
if self.data_format == 'channels_first':
rows = input_shape[3]
cols = input_shape[4]
Expand Down Expand Up @@ -329,6 +331,7 @@ def __init__(self, filters,

self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]

def build(self, input_shape):
if isinstance(input_shape, list):
Expand All @@ -349,6 +352,10 @@ def build(self, input_shape):
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
state_shape = [None] * 4
state_shape[channel_axis] = input_dim
state_shape = tuple(state_shape)
self.state_spec = [InputSpec(shape=state_shape), InputSpec(shape=state_shape)]
kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
self.kernel_shape = kernel_shape
recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
Expand Down

0 comments on commit 0d27d90

Please sign in to comment.