Skip to content

Commit

Permalink
Lambda test speed up (keras-team#9601)
Browse files Browse the repository at this point in the history
* lambda test speed up

* typo
  • Loading branch information
farizrahman4u authored and fchollet committed Mar 9, 2018
1 parent 6249bd5 commit 853774a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions tests/keras/layers/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,21 @@ def output_shape(input_shape):
def mask(inputs, mask=None):
return [None, None]

i = layers.Input(shape=(64, 64, 3))
i = layers.Input(shape=(3, 2, 1))
o = layers.Lambda(function=func,
output_shape=output_shape,
mask=mask)(i)

o1, o2 = o
assert o1._keras_shape == (None, 64, 64, 3)
assert o2._keras_shape == (None, 64, 64, 3)
assert o1._keras_shape == (None, 3, 2, 1)
assert o2._keras_shape == (None, 3, 2, 1)

model = Model(i, o)

x = np.random.random((4, 64, 64, 3))
x = np.random.random((4, 3, 2, 1))
out1, out2 = model.predict(x)
assert out1.shape == (4, 64, 64, 3)
assert out2.shape == (4, 64, 64, 3)
assert out1.shape == (4, 3, 2, 1)
assert out2.shape == (4, 3, 2, 1)
assert_allclose(out1, x * 0.2, atol=1e-4)
assert_allclose(out2, x * 0.3, atol=1e-4)

Expand All @@ -178,23 +178,23 @@ def func(x):
def output_shape(input_shape):
return [input_shape, input_shape]

i = layers.Input(shape=(64, 64, 3))
i = layers.Input(shape=(3, 2, 1))
o = layers.Lambda(function=func,
output_shape=output_shape)(i)

assert o[0]._keras_shape == (None, 64, 64, 3)
assert o[1]._keras_shape == (None, 64, 64, 3)
assert o[0]._keras_shape == (None, 3, 2, 1)
assert o[1]._keras_shape == (None, 3, 2, 1)

o = layers.add(o)
model = Model(i, o)

i2 = layers.Input(shape=(64, 64, 3))
i2 = layers.Input(shape=(3, 2, 1))
o2 = model(i2)
model2 = Model(i2, o2)

x = np.random.random((4, 64, 64, 3))
x = np.random.random((4, 3, 2, 1))
out = model2.predict(x)
assert out.shape == (4, 64, 64, 3)
assert out.shape == (4, 3, 2, 1)
assert_allclose(out, x * 0.2 + x * 0.3, atol=1e-4)

test_multiple_outputs_no_mask()
Expand Down

0 comments on commit 853774a

Please sign in to comment.