diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index f7f0a96ba68c..f77a067d0a52 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -165,14 +165,10 @@ class Recurrent(Layer):
         a specific layer, or on your entire model.
 
     # Note on specifying initial states in RNNs
-        Most RNN layers can be called with either a single tensor, or a list
-        of tensors.
-        If the an RNN layer is called with a single tensor, that tensor is
-        treated as the input, and the initial states are assigned an
-        appropriate default value.
-        If an RNN layer is called with a list of tensors, the first element is
-        treated as the input, and the remaining tensors are treated as the
-        initial states.
+        You can specify the initial state of RNN layers by calling them with
+        the keyword argument `initial_state`. The value of `initial_state`
+        should be a tensor or list of tensors representing the initial state
+        of the RNN layer.
     """
 
     def __init__(self, return_sequences=False,
@@ -188,7 +184,8 @@ def __init__(self, return_sequences=False,
         self.unroll = unroll
         self.implementation = 0
         self.supports_masking = True
-        self.input_spec = None
+        self.input_spec = InputSpec(ndim=3)
+        self.state_spec = None
         self.dropout = 0
         self.recurrent_dropout = 0
 
@@ -224,10 +221,58 @@ def get_initial_states(self, inputs):
     def preprocess_input(self, inputs, training=None):
         return inputs
 
+    def __call__(self, inputs, **kwargs):
+        with K.name_scope(self.name):
+            # Handle layer building (weight creation, input spec locking,
+            # state spec locking)
+            if not self.built:
+                # Raise exceptions in case the input is not compatible
+                # with the input_spec specified in the layer constructor.
+                self.assert_input_compatibility(inputs)
+
+                if hasattr(inputs, '_keras_shape'):
+                    input_shape = inputs._keras_shape
+                elif hasattr(K, 'int_shape'):
+                    input_shape = K.int_shape(inputs)
+                else:
+                    raise ValueError('You tried to call layer "' + self.name +
+                                     '". This layer has no information'
+                                     ' about its expected input shape, '
+                                     'and thus cannot be built. '
+                                     'You can build it manually via: '
+                                     '`layer.build(batch_input_shape)`')
+                self.build(input_shape)
+                self.built = True
+
+        # If initial_state is specified, add it to the inputs and temporarily
+        # modify the input spec to include the state
+        if 'initial_state' in kwargs:
+            # Compute the full input spec, including state
+            input_spec = self.input_spec
+            state_spec = self.state_spec
+            if not isinstance(state_spec, list):
+                state_spec = [state_spec]
+            self.input_spec = [input_spec] + state_spec
+
+            # Compute the full inputs, including state
+            initial_state = kwargs.pop('initial_state')
+            if not isinstance(initial_state, list):
+                initial_state = [initial_state]
+            inputs = [inputs] + initial_state
+
+            # Perform the call
+            output = super(Recurrent, self).__call__(inputs, **kwargs)
+
+            # Restore original input spec
+            self.input_spec = input_spec
+            return output
+
+        return super(Recurrent, self).__call__(inputs, **kwargs)
+
     def call(self, inputs, mask=None, training=None):
         # input shape: (nb_samples, time (padded with zeros), input_dim)
         # note that the .build() method of subclasses MUST define
-        # self.input_spec with a complete input shape.
+        # self.input_spec and self.state_spec with complete input shapes.
         if isinstance(inputs, list):
             initial_states = inputs[1:]
             inputs = inputs[0]
@@ -277,7 +322,8 @@ def call(self, inputs, mask=None, training=None):
     def reset_states(self):
         if not self.stateful:
             raise AttributeError('Layer must be stateful.')
-        if not self.batch_size:
+        batch_size = self.input_spec.shape[0]
+        if not batch_size:
             raise ValueError('If a RNN is stateful, it needs to know '
                              'its batch size. Specify the batch size '
                              'of your input tensors: \n'
@@ -290,9 +336,9 @@ def reset_states(self):
                              '`batch_shape` argument to your Input layer.')
         if self.states[0] is not None:
             for state in self.states:
-                K.set_value(state, np.zeros((self.batch_size, self.units)))
+                K.set_value(state, np.zeros((batch_size, self.units)))
         else:
-            self.states = [K.zeros((self.batch_size, self.units))
+            self.states = [K.zeros((batch_size, self.units))
                            for _ in self.states]
 
     def get_config(self):
@@ -394,8 +440,10 @@ def build(self, input_shape):
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
-        self.batch_size = input_shape[0]
+        batch_size = input_shape[0] if self.stateful else None
         self.input_dim = input_shape[2]
+        self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+        self.state_spec = InputSpec(shape=(batch_size, self.units))
 
         self.states = [None]
         if self.stateful:
@@ -608,8 +656,10 @@ def build(self, input_shape):
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
-        self.batch_size = input_shape[0]
+        batch_size = input_shape[0] if self.stateful else None
         self.input_dim = input_shape[2]
+        self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+        self.state_spec = InputSpec(shape=(batch_size, self.units))
 
         self.states = [None]
         if self.stateful:
@@ -883,8 +933,11 @@ def build(self, input_shape):
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
-        self.batch_size = input_shape[0]
+        batch_size = input_shape[0] if self.stateful else None
         self.input_dim = input_shape[2]
+        self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+        self.state_spec = [InputSpec(shape=(batch_size, self.units)),
+                           InputSpec(shape=(batch_size, self.units))]
 
         self.states = [None, None]
         if self.stateful:
diff --git a/tests/keras/layers/recurrent_test.py b/tests/keras/layers/recurrent_test.py
index 0c46fd6c6db2..c1635367e194 100644
--- a/tests/keras/layers/recurrent_test.py
+++ b/tests/keras/layers/recurrent_test.py
@@ -170,10 +170,10 @@ def test_from_config(layer_class):
 def test_specify_state(layer_class):
     states = 2 if layer_class is recurrent.LSTM else 1
     inputs = Input((timesteps, embedding_dim))
-    initial_states = [Input((units,)) for _ in range(states)]
+    initial_state = [Input((units,)) for _ in range(states)]
     layer = layer_class(units)
-    output = layer([inputs] + initial_states)
-    model = Model([inputs] + initial_states, output)
+    output = layer(inputs, initial_state=initial_state)
+    model = Model([inputs] + initial_state, output)
     model.compile(loss='categorical_crossentropy', optimizer='adam')
 
     inputs = np.random.random((num_samples, timesteps, embedding_dim))
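
For reference, a minimal usage sketch of the new `initial_state` call signature, mirroring `test_specify_state` above. It assumes the Keras 2 functional API as used in this diff; the dimension values are illustrative. An LSTM carries two state tensors (h and c), while SimpleRNN and GRU carry one.

    import numpy as np
    from keras.layers import Input, LSTM
    from keras.models import Model

    timesteps, input_dim, units, num_samples = 10, 8, 32, 4

    inputs = Input((timesteps, input_dim))
    # LSTM has two states, so `initial_state` is a list of two tensors.
    initial_state = [Input((units,)), Input((units,))]
    # New API: pass the state via keyword argument rather than calling
    # the layer with a list of [inputs] + states.
    output = LSTM(units)(inputs, initial_state=initial_state)

    # The state tensors are Input layers, so they are model inputs too.
    model = Model([inputs] + initial_state, output)
    model.compile(loss='mse', optimizer='adam')
    model.predict([np.random.random((num_samples, timesteps, input_dim)),
                   np.zeros((num_samples, units)),
                   np.zeros((num_samples, units))])

Because `__call__` restores the original `input_spec` after delegating to `Layer.__call__`, the same layer instance can subsequently be called without `initial_state`.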