
Commit

Pass initial_state as a keyword argument.
Joshua-Chin committed Mar 11, 2017
1 parent 5dbb612 commit 16343b3
Showing 2 changed files with 72 additions and 19 deletions.
85 changes: 69 additions & 16 deletions keras/layers/recurrent.py
@@ -165,14 +165,10 @@ class Recurrent(Layer):
        a specific layer, or on your entire model.
    # Note on specifying initial states in RNNs
-       Most RNN layers can be called with either a single tensor, or a list
-       of tensors.
-       If an RNN layer is called with a single tensor, that tensor is
-       treated as the input, and the initial states are assigned an
-       appropriate default value.
-       If an RNN layer is called with a list of tensors, the first element is
-       treated as the input, and the remaining tensors are treated as the
-       initial states.
+       You can specify the initial state of RNN layers by calling them with
+       the keyword argument `initial_state`. The value of `initial_state`
+       should be a tensor or list of tensors representing the initial state
+       of the RNN layer.
    """

    def __init__(self, return_sequences=False,
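Usage sketch of the new calling convention described in the docstring (not part of the diff; the layer size and input shapes are illustrative):

from keras.layers import Input, LSTM

x = Input((10, 8))   # (timesteps, input_dim)
h0 = Input((32,))    # initial hidden state
c0 = Input((32,))    # initial cell state

# Keyword form introduced by this commit; LSTM expects two state tensors.
y = LSTM(32)(x, initial_state=[h0, c0])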
@@ -188,7 +184,8 @@ def __init__(self, return_sequences=False,
        self.unroll = unroll
        self.implementation = 0
        self.supports_masking = True
-       self.input_spec = None
+       self.input_spec = InputSpec(ndim=3)
+       self.state_spec = None
        self.dropout = 0
        self.recurrent_dropout = 0

@@ -224,10 +221,58 @@ def get_initial_states(self, inputs):
    def preprocess_input(self, inputs, training=None):
        return inputs

+   def __call__(self, inputs, **kwargs):
+       with K.name_scope(self.name):
+           # Handle layer building (weight creation, input spec locking,
+           # state spec locking).
+           if not self.built:
+               # Raise exceptions in case the input is not compatible
+               # with the input_spec specified in the layer constructor.
+               self.assert_input_compatibility(inputs)
+
+               if hasattr(inputs, '_keras_shape'):
+                   input_shape = inputs._keras_shape
+               elif hasattr(K, 'int_shape'):
+                   input_shape = K.int_shape(inputs)
+               else:
+                   raise ValueError('You tried to call layer "' + self.name +
+                                    '". This layer has no information'
+                                    ' about its expected input shape, '
+                                    'and thus cannot be built. '
+                                    'You can build it manually via: '
+                                    '`layer.build(batch_input_shape)`')
+               self.build(input_shape)
+               self.built = True
+
+           # If initial_state is specified, add it to the inputs and temporarily
+           # modify the input spec to include the state.
+           if 'initial_state' in kwargs:
+               # Compute the full input spec, including state
+               input_spec = self.input_spec
+               state_spec = self.state_spec
+               if not isinstance(state_spec, list):
+                   state_spec = [state_spec]
+               self.input_spec = [input_spec] + state_spec
+
+               # Compute the full inputs, including state
+               initial_state = kwargs.pop('initial_state')
+               if not isinstance(initial_state, list):
+                   initial_state = [initial_state]
+               inputs = [inputs] + initial_state
+
+               # Perform the call
+               output = super(Recurrent, self).__call__(inputs, **kwargs)
+
+               # Restore original input spec
+               self.input_spec = input_spec
+               return output
+
+           return super(Recurrent, self).__call__(inputs, **kwargs)

    def call(self, inputs, mask=None, training=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        # note that the .build() method of subclasses MUST define
-       # self.input_spec with a complete input shape.
+       # self.input_spec and self.state_spec with complete input shapes.
        if isinstance(inputs, list):
            initial_states = inputs[1:]
            inputs = inputs[0]
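A hedged trace of the control flow the new `__call__` adds, for a keyword call (the layer and tensors here are hypothetical):

lstm = LSTM(32)                        # state_spec is set later, in build()
y = lstm(x, initial_state=[h0, c0])
# Inside __call__, roughly:
#   1. self.input_spec temporarily becomes [input_spec] + state_spec
#   2. inputs becomes [x, h0, c0]
#   3. Layer.__call__ runs; call() then splits inputs[0] from inputs[1:]
#   4. self.input_spec is restored to the original single spec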
@@ -277,7 +322,8 @@ def call(self, inputs, mask=None, training=None):
    def reset_states(self):
        if not self.stateful:
            raise AttributeError('Layer must be stateful.')
-       if not self.batch_size:
+       batch_size = self.input_spec.shape[0]
+       if not batch_size:
            raise ValueError('If a RNN is stateful, it needs to know '
                             'its batch size. Specify the batch size '
                             'of your input tensors: \n'
@@ -290,9 +336,9 @@ def reset_states(self):
                             '`batch_shape` argument to your Input layer.')
        if self.states[0] is not None:
            for state in self.states:
-               K.set_value(state, np.zeros((self.batch_size, self.units)))
+               K.set_value(state, np.zeros((batch_size, self.units)))
        else:
-           self.states = [K.zeros((self.batch_size, self.units))
+           self.states = [K.zeros((batch_size, self.units))
                           for _ in self.states]

    def get_config(self):
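A short sketch of the stateful path that `reset_states` serves; since the batch size is now read from `input_spec.shape[0]`, it must be pinned up front (sizes are illustrative):

from keras.models import Sequential
from keras.layers import LSTM

model = Sequential()
# batch_input_shape fixes the batch dimension, so input_spec.shape[0] == 16
model.add(LSTM(32, batch_input_shape=(16, 10, 8), stateful=True))
model.layers[0].reset_states()  # states are reset to zeros of shape (16, 32)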
@@ -394,8 +440,10 @@ def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

-       self.batch_size = input_shape[0]
+       batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
+       self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+       self.state_spec = InputSpec(shape=(batch_size, self.units))

        self.states = [None]
        if self.stateful:
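For a non-stateful layer the batch dimension stays unconstrained; e.g. building SimpleRNN(32) on inputs of shape (None, 10, 8) would yield (an illustration, assuming units=32):

# self.input_spec -> InputSpec(shape=(None, None, 8))  # (batch, time, input_dim)
# self.state_spec -> InputSpec(shape=(None, 32))       # (batch, units)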
@@ -608,8 +656,10 @@ def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

-       self.batch_size = input_shape[0]
+       batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
+       self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+       self.state_spec = InputSpec(shape=(batch_size, self.units))

        self.states = [None]
        if self.stateful:
@@ -883,8 +933,11 @@ def build(self, input_shape):
        if isinstance(input_shape, list):
            input_shape = input_shape[0]

-       self.batch_size = input_shape[0]
+       batch_size = input_shape[0] if self.stateful else None
        self.input_dim = input_shape[2]
+       self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
+       self.state_spec = [InputSpec(shape=(batch_size, self.units)),
+                          InputSpec(shape=(batch_size, self.units))]

        self.states = [None, None]
        if self.stateful:
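LSTM carries two states (hidden and cell), so `state_spec` is a two-element list and `__call__` will expect three inputs when `initial_state` is given. An assumed illustration for a stateful LSTM(32) with batch size 16:

# self.state_spec == [InputSpec(shape=(16, 32)),   # h: hidden state
#                     InputSpec(shape=(16, 32))]   # c: cell state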
6 changes: 3 additions & 3 deletions tests/keras/layers/recurrent_test.py
@@ -170,10 +170,10 @@ def test_from_config(layer_class):
def test_specify_state(layer_class):
    states = 2 if layer_class is recurrent.LSTM else 1
    inputs = Input((timesteps, embedding_dim))
-   initial_states = [Input((units,)) for _ in range(states)]
+   initial_state = [Input((units,)) for _ in range(states)]
    layer = layer_class(units)
-   output = layer([inputs] + initial_states)
-   model = Model([inputs] + initial_states, output)
+   output = layer(inputs, initial_state=initial_state)
+   model = Model([inputs] + initial_state, output)
    model.compile(loss='categorical_crossentropy', optimizer='adam')

    inputs = np.random.random((num_samples, timesteps, embedding_dim))
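The hunk is truncated here; exercising such a model would also require arrays for the state inputs, along these lines (a sketch following the test's conventions, not part of the commit):

initial_state = [np.random.random((num_samples, units))
                 for _ in range(states)]
targets = np.random.random((num_samples, units))
model.fit([inputs] + initial_state, targets)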
