Commit bbf4283

CuDNN RNN layers nested in TimeDistributed are not converted when loading (keras-team#10357)

* Add a unit test for CuDNNGRU conversion with TimeDistributed.

* Extract duplicated function convert_model() to _convert_model_weights().

* keras-team#10356 Convert weights of CuDNN/plain RNN nested in TimeDistributed.

Same case as for Bidirectional, except that in TimeDistributed there's only
one nested layer instead of two.
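To make the scenario concrete, here is a minimal sketch of the round-trip this commit fixes (a hedged illustration, not part of the commit: the file name and model setup are made up, and CuDNNGRU requires a GPU-enabled TensorFlow backend):

import numpy as np
from keras.models import Sequential
from keras.layers import InputLayer, TimeDistributed, GRU, CuDNNGRU

input_shape = (6, 6, 10)  # (outer timesteps, inner steps, features)

# A CuDNN model and its plain-GRU twin (gate setup matching CuDNN's).
cudnn_model = Sequential([InputLayer(input_shape),
                          TimeDistributed(CuDNNGRU(2))])
plain_model = Sequential([InputLayer(input_shape),
                          TimeDistributed(GRU(2, recurrent_activation='sigmoid',
                                              reset_after=True))])

# Round-trip the weights through HDF5. Before this fix, the CuDNN weight
# layout was not converted for layers nested inside TimeDistributed, so
# loading produced mismatched weights instead of an equivalent model.
cudnn_model.save_weights('weights.h5')
plain_model.load_weights('weights.h5')

x = np.random.random((4,) + input_shape)
np.testing.assert_allclose(cudnn_model.predict(x), plain_model.predict(x),
                           atol=1e-4)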
bzamecnik authored and fchollet committed Jun 6, 2018
1 parent 13548e8 commit bbf4283
Showing 2 changed files with 69 additions and 9 deletions.
14 changes: 13 additions & 1 deletion keras/engine/saving.py
@@ -502,6 +502,16 @@ def convert_nested_bidirectional(weights):
                                                           original_backend)
         return forward_weights + backward_weights

+    def convert_nested_time_distributed(weights):
+        """Converts layers nested in `TimeDistributed` wrapper by
+        `preprocess_weights_for_loading()`.
+
+        # Arguments
+            weights: List of weights values (Numpy arrays).
+
+        # Returns
+            A list of weights values (Numpy arrays).
+        """
+        return preprocess_weights_for_loading(
+            layer.layer, weights, original_keras_version, original_backend)

     def convert_nested_model(weights):
         """Converts layers nested in `Model` or `Sequential` by
         `preprocess_weights_for_loading()`.
@@ -535,11 +545,13 @@ def convert_nested_model(weights):
                 weights = weights[num_weights:]
         return new_weights

-    # Convert layers nested in Bidirectional/Model/Sequential.
+    # Convert layers nested in Bidirectional/TimeDistributed/Model/Sequential.
     # Both transformations should be run for both Keras 1->2 conversion
     # and for conversion of CuDNN layers.
     if layer.__class__.__name__ == 'Bidirectional':
         weights = convert_nested_bidirectional(weights)
+    if layer.__class__.__name__ == 'TimeDistributed':
+        weights = convert_nested_time_distributed(weights)
     elif layer.__class__.__name__ in ['Model', 'Sequential']:
         weights = convert_nested_model(weights)
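Why TimeDistributed is the simpler case: Bidirectional holds two copies of the inner RNN (exposed as forward_layer and backward_layer), whose weight lists must be split, converted separately, and concatenated, whereas TimeDistributed exposes a single inner layer through its layer attribute. A minimal sketch of that dispatch follows; convert_wrapper_weights and convert_fn are hypothetical names standing in for the real preprocess_weights_for_loading() machinery:

def convert_wrapper_weights(wrapper, weights, convert_fn):
    """Hypothetical sketch: convert weights of layers nested in a wrapper."""
    if wrapper.__class__.__name__ == 'Bidirectional':
        # Two nested layers: the flat weight list holds the forward
        # layer's weights followed by the backward layer's weights.
        num_weights = len(weights) // 2
        forward = convert_fn(wrapper.forward_layer, weights[:num_weights])
        backward = convert_fn(wrapper.backward_layer, weights[num_weights:])
        return forward + backward
    if wrapper.__class__.__name__ == 'TimeDistributed':
        # Only one nested layer: convert its weights directly.
        return convert_fn(wrapper.layer, weights)
    return weights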

64 changes: 56 additions & 8 deletions tests/test_model_saving.py
@@ -644,12 +644,6 @@ def test_load_weights_between_noncudnn_rnn(rnn_type, to_cudnn, bidirectional, im
         cudnn_rnn_layer_class = CuDNNGRU
         rnn_layer_kwargs['reset_after'] = True

-    def convert_model(source_model, target_model):
-        _, fname = tempfile.mkstemp('.h5')
-        source_model.save_weights(fname)
-        target_model.load_weights(fname)
-        os.remove(fname)

     layer = rnn_layer_class(units, **rnn_layer_kwargs)
     if bidirectional:
         layer = Bidirectional(layer)
@@ -662,9 +656,9 @@ def convert_model(source_model, target_model):
     cudnn_model = _make_nested_model(input_shape, cudnn_layer, model_nest_level, model_type)

     if to_cudnn:
-        convert_model(model, cudnn_model)
+        _convert_model_weights(model, cudnn_model)
     else:
-        convert_model(cudnn_model, model)
+        _convert_model_weights(cudnn_model, model)

     assert_allclose(model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4)

@@ -692,6 +686,60 @@ def make_nested_func_model(input_shape, layer, level=1):
         return make_nested_seq_model(input_shape, layer, level)


+def _convert_model_weights(source_model, target_model):
+    _, fname = tempfile.mkstemp('.h5')
+    source_model.save_weights(fname)
+    target_model.load_weights(fname)
+    os.remove(fname)


+@keras_test
+@pytest.mark.parametrize('to_cudnn', [False, True], ids=['from_cudnn', 'to_cudnn'])
+@pytest.mark.parametrize('rnn_type', ['LSTM', 'GRU'], ids=['LSTM', 'GRU'])
+@skipif_no_tf_gpu
+def test_load_weights_between_noncudnn_rnn_time_distributed(rnn_type, to_cudnn):
+    """
+    Similar to test_load_weights_between_noncudnn_rnn(), but the input has a
+    different rank due to the use of TimeDistributed. Issue: #10356.
+    """
+    input_size = 10
+    steps = 6
+    timesteps = 6
+    input_shape = (timesteps, steps, input_size)
+    units = 2
+    num_samples = 32
+    inputs = np.random.random((num_samples,) + input_shape)
+
+    rnn_layer_kwargs = {
+        'recurrent_activation': 'sigmoid',
+        # ensure biases are non-zero and properly converted
+        'bias_initializer': 'random_uniform',
+    }
+    if rnn_type == 'LSTM':
+        rnn_layer_class = LSTM
+        cudnn_rnn_layer_class = CuDNNLSTM
+    else:
+        rnn_layer_class = GRU
+        cudnn_rnn_layer_class = CuDNNGRU
+        rnn_layer_kwargs['reset_after'] = True
+
+    layer = rnn_layer_class(units, **rnn_layer_kwargs)
+    layer = TimeDistributed(layer)
+
+    cudnn_layer = cudnn_rnn_layer_class(units)
+    cudnn_layer = TimeDistributed(cudnn_layer)
+
+    model = _make_nested_model(input_shape, layer)
+    cudnn_model = _make_nested_model(input_shape, cudnn_layer)
+
+    if to_cudnn:
+        _convert_model_weights(model, cudnn_model)
+    else:
+        _convert_model_weights(cudnn_model, model)
+
+    assert_allclose(model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4)


 @skipif_no_tf_gpu
 def test_preprocess_weights_for_loading_gru_incompatible():
     """
