fixed tensorflow bug
yaringal committed Feb 20, 2016
1 parent fed7cc2 commit 9648332
Showing 3 changed files with 56 additions and 30 deletions.
keras/layers/embeddings.py: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def get_output(self, train=False):
         X = self.get_input(train)
         retain_p = 1. - self.p
         if train and self.p > 0:
-            B = K.random_binomial((self.input_dim), p=retain_p)
+            B = K.random_binomial((self.input_dim,), p=retain_p)
         else:
             B = K.ones((self.input_dim)) * retain_p
         out = K.gather(self.W * K.expand_dims(B), X)  # we zero-out rows of W at random
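The fix in this hunk is a single character, but worth spelling out: in Python, `(self.input_dim)` is just a parenthesised integer, while `(self.input_dim,)` is a one-element tuple. The Theano backend apparently tolerated the bare int as a shape, but the TensorFlow backend's `random_binomial` needs a proper shape sequence, so the missing comma broke embedding dropout there. A minimal illustration in plain Python, independent of Keras:

    input_dim = 1000
    shape_wrong = (input_dim)    # parentheses alone do nothing: still an int
    shape_right = (input_dim,)   # the trailing comma makes a one-element tuple
    print(type(shape_wrong).__name__, type(shape_right).__name__)  # int tuple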
keras/layers/recurrent.py: 22 additions & 12 deletions
@@ -78,6 +78,10 @@ class Recurrent(MaskedLayer):
     To reset the states of your model, call `.reset_states()` on either
     a specific layer, or on your entire model.
 
+    # Note on using dropout with TensorFlow
+        When using the TensorFlow backend, specify a fixed batch size for your model,
+        following the notes on stateful RNNs.
+
     '''
     input_ndim = 3
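A usage sketch of the docstring note just added, assuming the Keras-0.x Sequential API of this period and the `p_W`/`p_U` dropout parameters this branch introduces (batch size 32, 10 timesteps and 16 features are illustrative only):

    from keras.models import Sequential
    from keras.layers.recurrent import LSTM

    model = Sequential()
    # batch_input_shape pins the batch size (32 here), so the TensorFlow
    # backend can build the dropout masks with a concrete first dimension.
    model.add(LSTM(64, batch_input_shape=(32, 10, 16), p_W=0.5, p_U=0.5))
    model.compile(optimizer='rmsprop', loss='mse')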

@@ -252,8 +256,7 @@ def reset_states(self):
         input_shape = self.input_shape
         if not input_shape[0]:
             raise Exception('If a RNN is stateful, a complete ' +
-                            'input_shape must be provided ' +
-                            '(including batch size).')
+                            'input_shape must be provided (including batch size).')
         if hasattr(self, 'states'):
             K.set_value(self.states[0],
                         np.zeros((input_shape[0], self.output_dim)))
@@ -272,8 +275,11 @@ def step(self, x, states):
 
     def get_constants(self, X, train=False):
         nb_samples = K.shape(X)[0]
-        if K._BACKEND == 'tensorflow':
-            nb_samples = int(nb_samples)
+        if K._BACKEND == 'tensorflow' and train and self.p_W > 0 and self.p_U > 0:
+            if not self.input_shape[0]:
+                raise Exception('For RNN dropout in tensorflow, a complete ' +
+                                'input_shape must be provided (including batch size).')
+            nb_samples = self.input_shape[0]
         retain_p_W = 1. - self.p_W
         retain_p_U = 1. - self.p_U
         if train and self.p_W > 0 and self.p_U > 0:
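For context on why the replaced `int(nb_samples)` call failed: under the TensorFlow backend, the batch dimension of an input declared without a fixed batch size is statically unknown, so it has no integer value at graph-construction time; the new code therefore reads the batch size from `self.input_shape` and raises a clear error when it is missing. A rough sketch of the failure mode, assuming the graph-mode TensorFlow API of the period:

    import tensorflow as tf

    # Batch dimension left unspecified, as Keras does by default.
    x = tf.placeholder(tf.float32, shape=(None, 10, 16))
    batch_dim = x.get_shape()[0]  # statically unknown dimension
    # int(batch_dim) cannot produce a concrete value here and fails,
    # which is exactly what pinning the batch size up front avoids.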
@@ -391,8 +397,7 @@ def reset_states(self):
         input_shape = self.input_shape
         if not input_shape[0]:
             raise Exception('If a RNN is stateful, a complete ' +
-                            'input_shape must be provided ' +
-                            '(including batch size).')
+                            'input_shape must be provided (including batch size).')
         if hasattr(self, 'states'):
             K.set_value(self.states[0],
                         np.zeros((input_shape[0], self.output_dim)))
@@ -418,8 +423,11 @@ def step(self, x, states):
 
     def get_constants(self, X, train=False):
         nb_samples = K.shape(X)[0]
-        if K._BACKEND == 'tensorflow':
-            nb_samples = int(nb_samples)
+        if K._BACKEND == 'tensorflow' and train and self.p_W > 0 and self.p_U > 0:
+            if not self.input_shape[0]:
+                raise Exception('For RNN dropout in tensorflow, a complete ' +
+                                'input_shape must be provided (including batch size).')
+            nb_samples = self.input_shape[0]
         retain_p_W = 1. - self.p_W
         retain_p_U = 1. - self.p_U
         if train and self.p_W > 0 and self.p_U > 0:
@@ -553,8 +561,7 @@ def reset_states(self):
         input_shape = self.input_shape
         if not input_shape[0]:
             raise Exception('If a RNN is stateful, a complete ' +
-                            'input_shape must be provided ' +
-                            '(including batch size).')
+                            'input_shape must be provided (including batch size).')
         if hasattr(self, 'states'):
             K.set_value(self.states[0],
                         np.zeros((input_shape[0], self.output_dim)))
@@ -585,8 +592,11 @@ def step(self, x, states):
 
     def get_constants(self, X, train=False):
         nb_samples = K.shape(X)[0]
-        if K._BACKEND == 'tensorflow':
-            nb_samples = int(nb_samples)
+        if K._BACKEND == 'tensorflow' and train and self.p_W > 0 and self.p_U > 0:
+            if not self.input_shape[0]:
+                raise Exception('For RNN dropout in tensorflow, a complete ' +
+                                'input_shape must be provided (including batch size).')
+            nb_samples = self.input_shape[0]
         retain_p_W = 1. - self.p_W
         retain_p_U = 1. - self.p_U
         if train and self.p_W > 0 and self.p_U > 0:
tests/keras/layers/test_recurrent.py: 33 additions & 17 deletions
@@ -18,23 +18,39 @@ def _runner(layer_class):
     All the recurrent layers share the same interface,
     so we can run through them with a single function.
     """
-    for p in [0., 0.5]:
-        for ret_seq in [True, False]:
-            layer = layer_class(output_dim, return_sequences=ret_seq,
-                                weights=None, input_shape=(timesteps, embedding_dim),
-                                p_W=p, p_U=p)
-            layer.input = K.variable(np.ones((nb_samples, timesteps, embedding_dim)))
-            layer.get_config()
-
-            for train in [True, False]:
-                out = K.eval(layer.get_output(train))
-                # Make sure the output has the desired shape
-                if ret_seq:
-                    assert(out.shape == (nb_samples, timesteps, output_dim))
-                else:
-                    assert(out.shape == (nb_samples, output_dim))
-
-                mask = layer.get_output_mask(train)
+    for ret_seq in [True, False]:
+        layer = layer_class(output_dim, return_sequences=ret_seq,
+                            weights=None, input_shape=(timesteps, embedding_dim))
+        layer.input = K.variable(np.ones((nb_samples, timesteps, embedding_dim)))
+        layer.get_config()
+
+        for train in [True, False]:
+            out = K.eval(layer.get_output(train))
+            # Make sure the output has the desired shape
+            if ret_seq:
+                assert(out.shape == (nb_samples, timesteps, output_dim))
+            else:
+                assert(out.shape == (nb_samples, output_dim))
+
+            mask = layer.get_output_mask(train)
+
+    # check dropout
+    for ret_seq in [True, False]:
+        layer = layer_class(output_dim, return_sequences=ret_seq, weights=None,
+                            batch_input_shape=(nb_samples, timesteps, embedding_dim),
+                            p_W=0.5, p_U=0.5)
+        layer.input = K.variable(np.ones((nb_samples, timesteps, embedding_dim)))
+        layer.get_config()
+
+        for train in [True, False]:
+            out = K.eval(layer.get_output(train))
+            # Make sure the output has the desired shape
+            if ret_seq:
+                assert(out.shape == (nb_samples, timesteps, output_dim))
+            else:
+                assert(out.shape == (nb_samples, output_dim))
+
+            mask = layer.get_output_mask(train)
 
     # check statefulness
     model = Sequential()
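Note the deliberate split into two loops: the no-dropout cases keep the plain `input_shape` signature with a free batch dimension, while the dropout cases pass `batch_input_shape` so the batch size is pinned, matching the constraint the recurrent-layer change above introduces for TensorFlow.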
