Skip to content

Commit

Permalink
Do not assume outputs.dtype is equal to inputs.dtype in rnn() (tensor…
Browse files Browse the repository at this point in the history
…flow_backend.py) (keras-team#5715)

* Update tensorflow_backend.py

* Add TimeDistributed tests of Dense and Embedding layers with batch_input_size
  • Loading branch information
jarfo authored and fchollet committed Mar 11, 2017
1 parent 81aa60a commit fd9acf6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
3 changes: 2 additions & 1 deletion keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2255,8 +2255,9 @@ def rnn(step_function, inputs, initial_states,
states = tuple(initial_states)

time_steps = tf.shape(inputs)[0]
outputs, _ = step_function(inputs[0], initial_states + constants)
output_ta = tensor_array_ops.TensorArray(
dtype=inputs.dtype,
dtype=outputs.dtype,
size=time_steps,
tensor_array_name='output_ta')
input_ta = tensor_array_ops.TensorArray(
Expand Down
35 changes: 34 additions & 1 deletion tests/keras/layers/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.testing import assert_allclose
from keras.utils.test_utils import keras_test
from keras.layers import wrappers, Input
from keras.layers import core, convolutional, recurrent
from keras.layers import core, convolutional, recurrent, embeddings
from keras.models import Sequential, Model, model_from_json


Expand All @@ -19,6 +19,39 @@ def test_TimeDistributed():
# test config
model.get_config()

# test when specifying a batch_input_shape
test_input = np.random.random((1, 3, 4))
test_output = model.predict(test_input)
weights = model.layers[0].get_weights()

reference = Sequential()
reference.add(wrappers.TimeDistributed(core.Dense(2), batch_input_shape=(1, 3, 4)))
reference.add(core.Activation('relu'))
reference.compile(optimizer='rmsprop', loss='mse')
reference.layers[0].set_weights(weights)

reference_output = reference.predict(test_input)
assert_allclose(test_output, reference_output, atol=1e-05)

# test with Embedding
model = Sequential()
model.add(wrappers.TimeDistributed(embeddings.Embedding(5, 6), batch_input_shape=(10, 3, 4), dtype='int32'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(np.random.randint(5, size=(10, 3, 4), dtype='int32'), np.random.random((10, 3, 4, 6)), epochs=1, batch_size=10)

# compare to not using batch_input_shape
test_input = np.random.randint(5, size=(10, 3, 4), dtype='int32')
test_output = model.predict(test_input)
weights = model.layers[0].get_weights()

reference = Sequential()
reference.add(wrappers.TimeDistributed(embeddings.Embedding(5, 6), input_shape=(3, 4), dtype='int32'))
reference.compile(optimizer='rmsprop', loss='mse')
reference.layers[0].set_weights(weights)

reference_output = reference.predict(test_input)
assert_allclose(test_output, reference_output, atol=1e-05)

# test with Conv2D
model = Sequential()
model.add(wrappers.TimeDistributed(convolutional.Conv2D(5, (2, 2), padding='same'), input_shape=(2, 4, 4, 3)))
Expand Down

0 comments on commit fd9acf6

Please sign in to comment.