
Style fixes in layer wrappers.
fchollet committed Jan 13, 2017
1 parent 1b7800a commit 8ef4a3d
Showing 1 changed file with 47 additions and 40 deletions.
87 changes: 47 additions & 40 deletions keras/layers/wrappers.py
@@ -3,17 +3,17 @@


 class Wrapper(Layer):
     """Abstract wrapper base class.
     """
 
     def __init__(self, layer, **kwargs):
         self.layer = layer
         self.uses_learning_phase = layer.uses_learning_phase
         super(Wrapper, self).__init__(**kwargs)
 
     def build(self, input_shape=None):
-        """Assumes that self.layer is already set.
-        Should be called at the end of .build() in the
-        children classes.
-        """
+        # Assumes that self.layer is already set.
+        # Should be called at the end of .build() in the children classes.
         self.trainable_weights = getattr(self.layer, 'trainable_weights', [])
         self.non_trainable_weights = getattr(self.layer, 'non_trainable_weights', [])
         self.updates = getattr(self.layer, 'updates', [])
@@ -41,18 +41,19 @@ def from_config(cls, config):


 class TimeDistributed(Wrapper):
-    """This wrapper allows to apply a layer to every
-    temporal slice of an input.
+    """This wrapper allows to apply a layer to every temporal slice of an input.
 
-    The input should be at least 3D,
-    and the dimension of index one will be considered to be
-    the temporal dimension.
+    The input should be at least 3D, and the dimension of index one
+    will be considered to be the temporal dimension.
 
-    Consider a batch of 32 samples, where each sample is a sequence of 10
-    vectors of 16 dimensions. The batch input shape of the layer is then `(32, 10, 16)`
-    (and the `input_shape`, not including the samples dimension, is `(10, 16)`).
+    Consider a batch of 32 samples,
+    where each sample is a sequence of 10 vectors of 16 dimensions.
+    The batch input shape of the layer is then `(32, 10, 16)`,
+    and the `input_shape`, not including the samples dimension, is `(10, 16)`.
 
-    You can then use `TimeDistributed` to apply a `Dense` layer to each of the 10 timesteps, independently:
+    You can then use `TimeDistributed` to apply a `Dense` layer
+    to each of the 10 timesteps, independently:
 
     ```python
     # as the first layer in a model
     model = Sequential()
@@ -66,14 +67,16 @@ class TimeDistributed(Wrapper):
     The output will then have shape `(32, 10, 8)`.
 
-    Note this is strictly equivalent to using `layers.core.TimeDistributedDense`.
+    Note this is strictly equivalent to
+    using `layers.core.TimeDistributedDense`.
     However what is different about `TimeDistributed`
     is that it can be used with arbitrary layers, not just `Dense`,
     for instance with a `Convolution2D` layer:
 
     ```python
     model = Sequential()
-    model.add(TimeDistributed(Convolution2D(64, 3, 3), input_shape=(10, 3, 299, 299)))
+    model.add(TimeDistributed(Convolution2D(64, 3, 3),
+                              input_shape=(10, 3, 299, 299)))
     ```
# Arguments
@@ -99,15 +102,15 @@ def get_output_shape_for(self, input_shape):
         timesteps = input_shape[1]
         return (child_output_shape[0], timesteps) + child_output_shape[1:]
 
-    def call(self, X, mask=None):
-        input_shape = K.int_shape(X)
+    def call(self, inputs, mask=None):
+        input_shape = K.int_shape(inputs)
         if input_shape[0]:
             # batch size matters, use rnn-based implementation
-            def step(x, states):
+            def step(x, _):
                 output = self.layer.call(x)
                 return output, []
 
-            _, outputs, _ = K.rnn(step, X,
+            _, outputs, _ = K.rnn(step, inputs,
                                   initial_states=[],
                                   input_length=input_shape[1],
                                   unroll=False)
@@ -118,22 +121,24 @@ def step(x, states):
             # we can go with reshape-based implementation for performance
             input_length = input_shape[1]
             if not input_length:
-                input_length = K.shape(X)[1]
-            X = K.reshape(X, (-1,) + input_shape[2:])  # (nb_samples * timesteps, ...)
-            y = self.layer.call(X)  # (nb_samples * timesteps, ...)
+                input_length = K.shape(inputs)[1]
+            # (nb_samples * timesteps, ...)
+            inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+            y = self.layer.call(inputs)  # (nb_samples * timesteps, ...)
             # (nb_samples, timesteps, ...)
             output_shape = self.get_output_shape_for(input_shape)
             y = K.reshape(y, (-1, input_length) + output_shape[2:])
 
         # Apply activity regularizer if any:
-        if hasattr(self.layer, 'activity_regularizer') and self.layer.activity_regularizer is not None:
+        if (hasattr(self.layer, 'activity_regularizer') and
+                self.layer.activity_regularizer is not None):
             regularization_loss = self.layer.activity_regularizer(y)
-            self.add_loss(regularization_loss, X)
+            self.add_loss(regularization_loss, inputs)
         return y


 class Bidirectional(Wrapper):
-    """ Bidirectional wrapper for RNNs.
+    """Bidirectional wrapper for RNNs.
 
     # Arguments
         layer: `Recurrent` instance.
@@ -194,21 +199,21 @@ def get_output_shape_for(self, input_shape):
         elif self.merge_mode is None:
             return [self.forward_layer.get_output_shape_for(input_shape)] * 2
 
-    def call(self, X, mask=None):
-        Y = self.forward_layer.call(X, mask)
-        Y_rev = self.backward_layer.call(X, mask)
+    def call(self, inputs, mask=None):
+        y = self.forward_layer.call(inputs, mask)
+        y_rev = self.backward_layer.call(inputs, mask)
         if self.return_sequences:
-            Y_rev = K.reverse(Y_rev, 1)
+            y_rev = K.reverse(y_rev, 1)
         if self.merge_mode == 'concat':
-            return K.concatenate([Y, Y_rev])
+            return K.concatenate([y, y_rev])
         elif self.merge_mode == 'sum':
-            return Y + Y_rev
+            return y + y_rev
         elif self.merge_mode == 'ave':
-            return (Y + Y_rev) / 2
+            return (y + y_rev) / 2
         elif self.merge_mode == 'mul':
-            return Y * Y_rev
+            return y * y_rev
         elif self.merge_mode is None:
-            return [Y, Y_rev]
+            return [y, y_rev]
 
     def reset_states(self):
         self.forward_layer.reset_states()
@@ -230,13 +235,15 @@ def compute_mask(self, input, mask):
     @property
     def trainable_weights(self):
         if hasattr(self.forward_layer, 'trainable_weights'):
-            return self.forward_layer.trainable_weights + self.backward_layer.trainable_weights
+            return (self.forward_layer.trainable_weights +
+                    self.backward_layer.trainable_weights)
         return []
 
     @property
     def non_trainable_weights(self):
         if hasattr(self.forward_layer, 'non_trainable_weights'):
-            return self.forward_layer.non_trainable_weights + self.backward_layer.non_trainable_weights
+            return (self.forward_layer.non_trainable_weights +
+                    self.backward_layer.non_trainable_weights)
         return []
 
     @property
@@ -253,11 +260,11 @@ def losses(self):

     @property
     def constraints(self):
-        _constraints = {}
+        constraints = {}
         if hasattr(self.forward_layer, 'constraints'):
-            _constraints.update(self.forward_layer.constraints)
-            _constraints.update(self.backward_layer.constraints)
-        return _constraints
+            constraints.update(self.forward_layer.constraints)
+            constraints.update(self.backward_layer.constraints)
+        return constraints
 
     def get_config(self):
         config = {"merge_mode": self.merge_mode}
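For readers skimming the diff, here is a minimal NumPy sketch of the reshape-based path in `TimeDistributed.call` above: collapse the batch and time axes, apply the wrapped layer once, then restore the time axis. This is illustrative only, not Keras code; the function name `time_distributed_dense` and the toy shapes are assumptions chosen to match the docstring's `(32, 10, 16)` example.

```python
import numpy as np


def time_distributed_dense(x, w, b):
    """Apply one dense layer (w, b) to every timestep of x.

    x: (nb_samples, timesteps, input_dim)
    w: (input_dim, output_dim)
    b: (output_dim,)
    """
    nb_samples, timesteps, input_dim = x.shape
    # Collapse batch and time axes: (nb_samples * timesteps, input_dim)
    x_flat = x.reshape((-1, input_dim))
    # One big matmul instead of one per timestep
    y_flat = x_flat.dot(w) + b
    # Restore the time axis: (nb_samples, timesteps, output_dim)
    return y_flat.reshape((nb_samples, timesteps, -1))


x = np.random.rand(32, 10, 16)
w = np.random.rand(16, 8)
b = np.zeros(8)
print(time_distributed_dense(x, w, b).shape)  # (32, 10, 8)
```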

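Similarly, a toy sketch of the merge logic in `Bidirectional.call` (again NumPy rather than the Keras backend; `merge_bidirectional` is a hypothetical helper). When full sequences are returned, the backward output is time-reversed to re-align it with the forward output, then the two are combined according to `merge_mode`:

```python
import numpy as np


def merge_bidirectional(y, y_rev, merge_mode='concat', return_sequences=True):
    """Combine forward output y and backward output y_rev,
    both shaped (nb_samples, timesteps, output_dim)."""
    if return_sequences:
        # Mirrors K.reverse(y_rev, 1): re-align the backward time axis
        y_rev = y_rev[:, ::-1, :]
    if merge_mode == 'concat':
        return np.concatenate([y, y_rev], axis=-1)
    elif merge_mode == 'sum':
        return y + y_rev
    elif merge_mode == 'ave':
        return (y + y_rev) / 2
    elif merge_mode == 'mul':
        return y * y_rev
    elif merge_mode is None:
        return [y, y_rev]


y = np.random.rand(32, 10, 8)
y_rev = np.random.rand(32, 10, 8)
print(merge_bidirectional(y, y_rev, 'concat').shape)  # (32, 10, 16)
print(merge_bidirectional(y, y_rev, 'sum').shape)     # (32, 10, 8)
```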
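Finally, an end-to-end usage sketch against the Keras 1.x API this file belongs to (import paths and layer names are assumptions based on the era of the diff; modern Keras has since renamed several of them):

```python
from keras.models import Sequential
from keras.layers import Dense, LSTM, TimeDistributed, Bidirectional

model = Sequential()
# Apply the same Dense(8) to each of 10 timesteps: (None, 10, 16) -> (None, 10, 8)
model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))
# Run an LSTM over the sequence in both directions and concatenate,
# giving 2 * 32 = 64 features per timestep
model.add(Bidirectional(LSTM(32, return_sequences=True), merge_mode='concat'))
model.compile(optimizer='rmsprop', loss='mse')
model.summary()
```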