Skip to content

Commit

Permalink
Add option to allow loading of weights with mismatch. (keras-team#8462)
Browse files Browse the repository at this point in the history
* Add option to allow loading of weights with mismatch.

* Add unit test for skipping layers with mismatching weights.

* Add nose as travis dependency (required for unit tests).
  • Loading branch information
hgaiser authored and fchollet committed Dec 15, 2017
1 parent 04e0a10 commit 0611d80
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install:
# Useful for debugging any issues with conda
- conda info -a

- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION numpy nose scipy matplotlib pandas pytest h5py
- source activate test-environment
- conda install mkl mkl-service
- pip install theano
Expand Down
49 changes: 37 additions & 12 deletions keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,7 +2597,7 @@ def save_weights(self, filepath, overwrite=True):
f.flush()
f.close()

def load_weights(self, filepath, by_name=False):
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
"""Loads all layer weights from a HDF5 save file.
If `by_name` is False (default) weights are loaded
Expand All @@ -2616,6 +2616,11 @@ def load_weights(self, filepath, by_name=False):
filepath: String, path to the weights file to load.
by_name: Boolean, whether to load weights by name
or by topological order.
skip_mismatch: Boolean, whether to skip loading of layers
where there is a mismatch in the number of weights,
or a mismatch in the shape of the weights
(only valid when `by_name=True`).
# Raises
ImportError: If h5py is not available.
Expand All @@ -2626,7 +2631,8 @@ def load_weights(self, filepath, by_name=False):
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
if by_name:
load_weights_from_hdf5_group_by_name(f, self.layers)
load_weights_from_hdf5_group_by_name(
f, self.layers, skip_mismatch=skip_mismatch)
else:
load_weights_from_hdf5_group(f, self.layers)

Expand Down Expand Up @@ -3152,7 +3158,7 @@ def load_weights_from_hdf5_group(f, layers):
K.batch_set_value(weight_value_tuples)


def load_weights_from_hdf5_group_by_name(f, layers):
def load_weights_from_hdf5_group_by_name(f, layers, skip_mismatch=False):
"""Implements name-based weight loading.
(instead of topological weight loading).
Expand All @@ -3161,11 +3167,14 @@ def load_weights_from_hdf5_group_by_name(f, layers):
# Arguments
f: A pointer to a HDF5 group.
layers: a list of target layers.
layers: A list of target layers.
skip_mismatch: Boolean, whether to skip loading of layers
where there is a mismatch in the number of weights,
or a mismatch in the shape of the weights.
# Raises
ValueError: in case of mismatch between provided layers
and weights file.
and weights file and skip_mismatch=False.
"""
if 'keras_version' in f.attrs:
original_keras_version = f.attrs['keras_version'].decode('utf8')
Expand Down Expand Up @@ -3201,15 +3210,31 @@ def load_weights_from_hdf5_group_by_name(f, layers):
original_keras_version,
original_backend)
if len(weight_values) != len(symbolic_weights):
raise ValueError('Layer #' + str(k) +
' (named "' + layer.name +
'") expects ' +
str(len(symbolic_weights)) +
' weight(s), but the saved weights' +
' have ' + str(len(weight_values)) +
' element(s).')
if skip_mismatch:
warnings.warn('Skipping loading of weights for layer {}'.format(layer.name) +
' due to mismatch in number of weights' +
' ({} vs {}).'.format(len(symbolic_weights), len(weight_values)))
continue
else:
raise ValueError('Layer #' + str(k) +
' (named "' + layer.name +
'") expects ' +
str(len(symbolic_weights)) +
' weight(s), but the saved weights' +
' have ' + str(len(weight_values)) +
' element(s).')
# Set values.
for i in range(len(weight_values)):
if skip_mismatch:
if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
warnings.warn('Skipping loading of weights for layer {}'.format(layer.name) +
' due to mismatch in shape' +
' ({} vs {}).'.format(
symbolic_weights[i].shape,
weight_values[i].shape))
continue

weight_value_tuples.append((symbolic_weights[i],
weight_values[i]))

K.batch_set_value(weight_value_tuples)
5 changes: 3 additions & 2 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def set_weights(self, weights):
self.build()
self.model.set_weights(weights)

def load_weights(self, filepath, by_name=False):
def load_weights(self, filepath, by_name=False, skip_mismatch=False):
if h5py is None:
raise ImportError('`load_weights` requires h5py.')
f = h5py.File(filepath, mode='r')
Expand All @@ -740,7 +740,8 @@ def load_weights(self, filepath, by_name=False):
else:
layers = self.layers
if by_name:
topology.load_weights_from_hdf5_group_by_name(f, layers)
topology.load_weights_from_hdf5_group_by_name(f, layers,
skip_mismatch=skip_mismatch)
else:
topology.load_weights_from_hdf5_group(f, layers)
if hasattr(f, 'close'):
Expand Down
50 changes: 49 additions & 1 deletion tests/test_model_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import tempfile
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_raises

from keras import backend as K
from keras.models import Model, Sequential
Expand Down Expand Up @@ -277,6 +277,54 @@ def test_loading_weights_by_name_2():
assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0


@keras_test
def test_loading_weights_by_name_skip_mismatch():
    """Test skipping layers while loading model weights by name.

    Covers a sequential model: weights are saved, the model is rebuilt
    with a differently sized 'morty' layer, and the weights are loaded
    with `by_name=True, skip_mismatch=True`. A `UserWarning` is expected,
    'rick' must match the saved values and 'morty' must not.
    """
    # test with custom optimizer, loss
    custom_opt = optimizers.rmsprop
    custom_loss = losses.mse

    # sequential model
    model = Sequential()
    model.add(Dense(2, input_shape=(3,), name='rick'))
    model.add(Dense(3, name='morty'))
    model.compile(loss=custom_loss, optimizer=custom_opt(), metrics=['acc'])

    x = np.random.random((1, 3))
    y = np.random.random((1, 3))
    model.train_on_batch(x, y)

    old_weights = [layer.get_weights() for layer in model.layers]

    # mkstemp hands back an already-open OS-level descriptor; close it right
    # away so it is not leaked (only the path is needed below).
    fd, fname = tempfile.mkstemp('.h5')
    os.close(fd)
    try:
        model.save_weights(fname)

        # delete and recreate model
        del model
        model = Sequential()
        model.add(Dense(2, input_shape=(3,), name='rick'))
        # different shape w.r.t. previous model
        model.add(Dense(4, name='morty'))
        model.compile(loss=custom_loss, optimizer=custom_opt(),
                      metrics=['acc'])

        # load weights from first model
        # expect UserWarning for skipping weights
        with pytest.warns(UserWarning):
            model.load_weights(fname, by_name=True, skip_mismatch=True)
    finally:
        # always clean up the temp file, even if saving/loading failed
        os.remove(fname)

    # assert layers 'rick' are equal
    for old, new in zip(old_weights[0], model.layers[0].get_weights()):
        assert_allclose(old, new, atol=1e-05)

    # assert layers 'morty' are not equal, since we skipped loading this layer
    for old, new in zip(old_weights[1], model.layers[1].get_weights()):
        assert_raises(AssertionError, assert_allclose, old, new, atol=1e-05)


# a function to be called from the Lambda layer
def square_fn(x):
return x * x
Expand Down

0 comments on commit 0611d80

Please sign in to comment.