Skip to content

Commit 095b6c1

Browse files
wxsfchollet
authored andcommitted
Fix merge conflicts
2 parents 827ec65 + c94cf4b commit 095b6c1

9 files changed

+65
-26
lines changed

keras/backend/tensorflow_backend.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ def mean(x, axis=None, keepdims=False):
154154
def any(x, axis=None, keepdims=False):
155155
'''Bitwise reduction (logical OR).
156156
157-
Return array of int8 (0s and 1s).
157+
Return array of uint8 (0s and 1s).
158158
'''
159159
axis = normalize_axis(axis, ndim(x))
160160
x = tf.cast(x, tf.bool)
161161
x = tf.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
162-
return tf.cast(x, tf.int8)
162+
return tf.cast(x, tf.uint8)
163163

164164

165165
def argmax(x, axis=-1):
@@ -438,7 +438,10 @@ def rnn(step_function, inputs, initial_states,
438438

439439
if mask is not None:
440440
# Transpose not supported by bool tensor types, hence round-trip to uint8.
441-
mask = tf.cast(tf.transpose(tf.cast(mask, tf.uint8), axes), tf.bool)
441+
mask = tf.cast(mask, tf.uint8)
442+
if len(mask.get_shape()) == ndim-1:
443+
mask = expand_dims(mask)
444+
mask = tf.cast(tf.transpose(mask, axes), tf.bool)
442445
mask_list = tf.unpack(mask)
443446

444447
for input, mask_t in zip(input_list, mask_list):

keras/backend/theano_backend.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def rnn(step_function, inputs, initial_states,
427427
the step function.
428428
go_backwards: boolean. If True, do the iteration over
429429
the time dimension in reverse order.
430-
mask: binary tensor with shape (samples, time, 1),
430+
mask: binary tensor with shape (samples, time),
431431
with a zero for every element that is masked.
432432
433433
Returns
@@ -447,6 +447,9 @@ def rnn(step_function, inputs, initial_states,
447447
if mask is None:
448448
mask = expand_dims(ones_like(T.sum(inputs, axis=-1)))
449449
else:
450+
if mask.ndim == ndim-1:
451+
mask = expand_dims(mask)
452+
assert mask.ndim == ndim
450453
mask = mask.dimshuffle(axes)
451454

452455
def _step(input, mask, output_tm1, *states):
@@ -665,6 +668,7 @@ def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
665668
pool_out = pool_out.dimshuffle((0, 2, 3, 1))
666669
return pool_out
667670

671+
668672
# RANDOMNESS
669673

670674

keras/layers/core.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -323,17 +323,16 @@ class Masking(MaskedLayer):
323323
def __init__(self, mask_value=0., **kwargs):
324324
super(Masking, self).__init__(**kwargs)
325325
self.mask_value = mask_value
326-
self.input = K.placeholder(ndim=3)
326+
if (not hasattr(self, 'input')):
327+
self.input = K.placeholder(ndim=3)
327328

328329
def get_output_mask(self, train=False):
329330
X = self.get_input(train)
330-
return K.any(K.ones_like(X) * (1. - K.equal(X, self.mask_value)),
331-
axis=-1)
331+
return K.any(K.not_equal(X, self.mask_value), axis=-1)
332332

333333
def get_output(self, train=False):
334334
X = self.get_input(train)
335-
return X * K.any((1. - K.equal(X, self.mask_value)),
336-
axis=-1, keepdims=True)
335+
return X * K.cast(K.any(K.not_equal(X, self.mask_value), axis=-1, keepdims=True), K.floatx())
337336

338337
def get_config(self):
339338
config = {'name': self.__class__.__name__,

keras/layers/embeddings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_output_mask(self, train=None):
8989
if not self.mask_zero:
9090
return None
9191
else:
92-
return K.expand_dims(K.not_equal(X, 0))
92+
return K.not_equal(X, 0)
9393

9494
@property
9595
def output_shape(self):

keras/objectives.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def hinge(y_true, y_pred):
3737
def categorical_crossentropy(y_true, y_pred):
3838
'''Expects a binary class matrix instead of a vector of scalar classes.
3939
'''
40-
return K.mean(K.categorical_crossentropy(y_pred, y_true), axis=-1)
40+
return K.categorical_crossentropy(y_pred, y_true)
4141

4242

4343
def binary_crossentropy(y_true, y_pred):
@@ -49,11 +49,9 @@ def poisson(y_true, y_pred):
4949

5050

5151
def cosine_proximity(y_true, y_pred):
52-
assert K.ndim(y_true) == 2
53-
assert K.ndim(y_pred) == 2
54-
y_true = K.l2_normalize(y_true, axis=1)
55-
y_pred = K.l2_normalize(y_pred, axis=1)
56-
return -K.mean(y_true * y_pred, axis=1)
52+
y_true = K.l2_normalize(y_true, axis=-1)
53+
y_pred = K.l2_normalize(y_pred, axis=-1)
54+
return -K.mean(y_true * y_pred, axis=-1)
5755

5856

5957
# aliases

tests/keras/layers/test_core.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,11 @@ def test_naming():
172172
model.train_on_batch(np.random.random((2, 2)), np.random.random((2, 2)))
173173

174174

175-
@pytest.mark.skipif(K._BACKEND == 'tensorflow',
176-
reason='currently not working with TensorFlow')
177175
def test_sequences():
178176
'''Test masking sequences with zeroes as padding'''
179177
# integer inputs, one per timestep, like embeddings
180178
layer = core.Masking()
181-
func = K.function([layer.input], [layer.get_output_mask()])
179+
func = K.function([layer.get_input(True)], [layer.get_output_mask()])
182180
input_data = np.array([[[1], [2], [3], [0]],
183181
[[0], [4], [5], [0]]], dtype=np.int32)
184182

@@ -190,8 +188,6 @@ def test_sequences():
190188
assert np.all(output == expected), 'Output not as expected'
191189

192190

193-
@pytest.mark.skipif(K._BACKEND == 'tensorflow',
194-
reason='currently not working with TensorFlow')
195191
def test_non_zero():
196192
'''Test masking with non-zero mask value'''
197193
layer = core.Masking(5)
@@ -204,8 +200,6 @@ def test_non_zero():
204200
assert np.all(output == expected), 'Output not as expected'
205201

206202

207-
@pytest.mark.skipif(K._BACKEND == 'tensorflow',
208-
reason='currently not working with TensorFlow')
209203
def test_non_zero_output():
210204
'''Test output of masking layer with non-zero mask value'''
211205
layer = core.Masking(5)

tests/keras/layers/test_recurrent.py

+18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from numpy.testing import assert_allclose
44

55
from keras.layers import recurrent, embeddings
6+
from keras.models import Sequential
7+
from keras.layers.core import Masking
8+
69
from keras import backend as K
710
from keras.models import Sequential, model_from_json
811

@@ -111,5 +114,20 @@ def test_batch_input_shape_serialization():
111114
assert(reconstructed_model.input_shape == (2, 2))
112115

113116

117+
def test_masking_layer():
118+
''' This test based on a previously failing issue here:
119+
https://github.com/fchollet/keras/issues/1567
120+
121+
'''
122+
model = Sequential()
123+
model.add(Masking(input_shape=(3, 4)))
124+
model.add(recurrent.LSTM(output_dim=5, return_sequences=True))
125+
model.compile(loss='categorical_crossentropy', optimizer='adam')
126+
I = np.random.random((6, 3, 4))
127+
V = np.abs(np.random.random((6, 3, 5)))
128+
V /= V.sum(axis=-1, keepdims=True)
129+
model.fit(I, V, nb_epoch=1, batch_size=100, verbose=1)
130+
131+
114132
if __name__ == '__main__':
115133
pytest.main([__file__])

tests/keras/test_objectives.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
from keras import objectives
4+
from keras import backend as K
5+
6+
7+
allobj = [objectives.mean_squared_error, objectives.root_mean_squared_error,
8+
objectives.mean_absolute_error, objectives.mean_absolute_percentage_error,
9+
objectives.mean_squared_logarithmic_error, objectives.squared_hinge,
10+
objectives.hinge, objectives.categorical_crossentropy, objectives.binary_crossentropy, objectives.poisson,
11+
objectives.cosine_proximity]
12+
13+
def test_objective_shapes_3d():
14+
y_a = K.variable(np.random.random((5, 6, 7)))
15+
y_b = K.variable(np.random.random((5, 6, 7)))
16+
for obj in allobj:
17+
objective_output = obj(y_a, y_b)
18+
assert K.eval(objective_output).shape == (5, 6)
19+
20+
def test_objective_shapes_2d():
21+
y_a = K.variable(np.random.random((6, 7)))
22+
y_b = K.variable(np.random.random((6, 7)))
23+
for obj in allobj:
24+
objective_output = obj(y_a, y_b)
25+
assert K.eval(objective_output).shape == (6,)

tests/test_loss_masking.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@
77
from keras import backend as K
88

99

10-
@pytest.mark.skipif(K._BACKEND == 'tensorflow',
11-
reason='currently not working with TensorFlow')
1210
def test_masking():
1311
np.random.seed(1337)
1412
X = np.array(
1513
[[[1, 1], [2, 1], [3, 1], [5, 5]],
1614
[[1, 5], [5, 0], [0, 0], [0, 0]]], dtype=np.int32)
1715
model = Sequential()
18-
model.add(Masking(mask_value=0, input_shape=(None, 2)))
16+
model.add(Masking(mask_value=0, input_shape=(4, 2)))
1917
model.add(TimeDistributedDense(1, init='one'))
2018
model.compile(loss='mse', optimizer='sgd')
2119
y = model.predict(X)

0 commit comments

Comments
 (0)