Skip to content

Commit

Permalink
[RELNOTES] Add clone_model. (keras-team#7631)
Browse files Browse the repository at this point in the history
* Add `clone_model`.

* Restrict use of `clear_session` in test to TF backend.

* Fix naming issue
  • Loading branch information
fchollet authored Aug 14, 2017
1 parent 57bc4da commit e27b8b9
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 2 deletions.
3 changes: 2 additions & 1 deletion keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1790,7 +1790,8 @@ def build_map_of_graph(tensor, finished_nodes, nodes_in_progress,
raise RuntimeError('The name "' + name + '" is used ' +
str(all_names.count(name)) +
' times in the model. '
'All layer names should be unique.')
'All layer names should be unique. '
'Layer names: ', all_names)

# Layer parameters.
# The new container starts with a single inbound node
Expand Down
230 changes: 229 additions & 1 deletion keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import optimizers
from . import layers as layer_module
from .utils.io_utils import ask_to_proceed_with_overwrite
from .utils.generic_utils import has_arg
from .engine.training import Model
from .engine import topology
from .engine.topology import Layer
Expand Down Expand Up @@ -560,7 +561,6 @@ def build(self, input_shape=None):
# Make sure child model callbacks
# will call the parent Sequential model.
self.model.callback_model = self

self.built = True

@property
Expand Down Expand Up @@ -1299,3 +1299,231 @@ def get_or_create_layer(layer_data):
layer = get_or_create_layer(conf)
model.add(layer)
return model


def _clone_functional_model(model, input_tensors=None):
    """Clone a functional `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    # Arguments
        model: Instance of `Model`.
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.

    # Returns
        An instance of `Model` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights.

    # Raises
        ValueError: in case of invalid `model` argument value.
    """
    if not isinstance(model, Model):
        raise ValueError('Expected `model` argument '
                         'to be a `Model` instance, got ', model)
    if isinstance(model, Sequential):
        raise ValueError('Expected `model` argument '
                         'to be a functional `Model` instance, '
                         'got a `Sequential` instance instead:', model)

    layer_map = {}  # Cache {reference_layer: cloned_layer}.
    tensor_map = {}  # Map {reference_tensor: (corresponding_tensor, mask)}
    if input_tensors is None:
        # Create placeholders to build the model on top of.
        input_tensors = []
        for layer in model.input_layers:
            input_tensor = Input(batch_shape=layer.batch_input_shape,
                                 dtype=layer.dtype,
                                 sparse=layer.sparse,
                                 name=layer.name)
            input_tensors.append(input_tensor)
            # Cache newly created input layer.
            layer_map[layer] = input_tensor._keras_history[0]
    else:
        # Make sure that all input tensors come from a Keras layer;
        # tensors that do not are wrapped in a new input layer.
        input_tensors = topology._to_list(input_tensors)
        _input_tensors = []
        for i, x in enumerate(input_tensors):
            if not K.is_keras_tensor(x):
                original_input_layer = model.input_layers[i]
                input_tensor = Input(
                    tensor=x,
                    name='input_wrapper_for_' + original_input_layer.name)
                _input_tensors.append(input_tensor)
                # Key the cache on the reference model's input layer so
                # that the node loop below reuses the wrapper instead of
                # cloning (and calling) the reference input layer again.
                wrapper_layer = input_tensor._keras_history[0]
                layer_map[original_input_layer] = wrapper_layer
            else:
                _input_tensors.append(x)
        input_tensors = _input_tensors

    for x, y in zip(model.inputs, input_tensors):
        tensor_map[x] = (y, None)  # tensor, mask

    # Iterate over every node in the reference model,
    # from the largest depth key down to the smallest.
    depth_keys = sorted(model.nodes_by_depth, reverse=True)
    for depth in depth_keys:
        for node in model.nodes_by_depth[depth]:
            # Recover the corresponding layer.
            layer = node.outbound_layer

            # Get or create layer.
            if layer not in layer_map:
                # Clone layer.
                new_layer = layer.__class__.from_config(layer.get_config())
                layer_map[layer] = new_layer
                layer = new_layer
            else:
                # Reuse previously cloned layer.
                layer = layer_map[layer]
                # Don't call InputLayer multiple times.
                if isinstance(layer, topology.InputLayer):
                    continue

            # Gather inputs to call the new layer.
            reference_input_tensors = node.input_tensors
            reference_output_tensors = node.output_tensors

            # The layer can only be called once every one of its
            # reference inputs has a counterpart in tensor_map.
            computed_data = [tensor_map[x]
                             for x in reference_input_tensors
                             if x in tensor_map]  # Tuples (input, mask).
            if len(computed_data) != len(reference_input_tensors):
                continue

            # Copy the call arguments: `mask` may be injected below and
            # must not leak into the reference model's node.
            kwargs = dict(node.arguments) if node.arguments else {}

            if len(computed_data) == 1:
                # Single input: pass the bare tensor and mask.
                call_inputs, call_masks = computed_data[0]
            else:
                # Several inputs: pass lists of tensors and masks.
                call_inputs = [data[0] for data in computed_data]
                call_masks = [data[1] for data in computed_data]

            if has_arg(layer.call, 'mask') and 'mask' not in kwargs:
                kwargs['mask'] = call_masks
            output_tensors = topology._to_list(
                layer(call_inputs, **kwargs))
            output_masks = topology._to_list(
                layer.compute_mask(call_inputs, call_masks))

            # Update tensor_map.
            for x, y, mask in zip(reference_output_tensors,
                                  output_tensors,
                                  output_masks):
                tensor_map[x] = (y, mask)

    # Check that we did compute the model outputs,
    # then instantiate a new model from inputs and outputs.
    output_tensors = []
    for x in model.outputs:
        assert x in tensor_map, 'Could not compute output ' + str(x)
        tensor, _ = tensor_map[x]
        output_tensors.append(tensor)
    return Model(input_tensors, output_tensors, name=model.name)


def _clone_sequential_model(model, input_tensors=None):
    """Build a fresh copy of a `Sequential` model.

    Every layer of `model` is re-instantiated from its config, so the
    clone gets new layers (and thus new weights) rather than sharing
    the layers of the original model.

    # Arguments
        model: Instance of `Sequential`.
        input_tensors: optional input tensor (or list holding a single
            tensor) to build the model upon. If not provided,
            the clone is built from the cloned layers alone.

    # Returns
        An instance of `Sequential` made of cloned layers, carrying
        the same name as `model`. When `input_tensors` is given,
        an input layer for that tensor is placed in front.

    # Raises
        ValueError: if `model` is not a `Sequential` instance,
            if `input_tensors` does not hold exactly one tensor,
            or if that tensor is a Keras tensor produced by a layer
            other than an `InputLayer`.
    """
    if not isinstance(model, Sequential):
        raise ValueError('Expected `model` argument '
                         'to be a `Sequential` model instance, '
                         'but got:', model)

    # Re-create each layer from its own configuration.
    cloned_layers = []
    for source_layer in model.layers:
        layer_cls = source_layer.__class__
        cloned_layers.append(
            layer_cls.from_config(source_layer.get_config()))

    if input_tensors is None:
        return Sequential(layers=cloned_layers, name=model.name)

    tensor_list = topology._to_list(input_tensors)
    if len(tensor_list) != 1:
        raise ValueError('To clone a `Sequential` model, we expect '
                         ' at most one tensor '
                         'as part of `input_tensors`.')

    x = tensor_list[0]
    if not K.is_keras_tensor(x):
        # Plain backend tensor: wrap it in a new input layer.
        wrapped = Input(tensor=x,
                        name='input_wrapper_for_' + str(x.name))
        first_layer = wrapped._keras_history[0]
    else:
        # Keras tensor: only accepted when it comes from an InputLayer.
        first_layer = x._keras_history[0]
        if not isinstance(first_layer, topology.InputLayer):
            raise ValueError('Cannot clone a `Sequential` model on top '
                             'of a tensor that comes from a Keras layer '
                             'other than an `InputLayer`. '
                             'Use the functional API instead.')
    return Sequential(layers=[first_layer] + cloned_layers,
                      name=model.name)


def clone_model(model, input_tensors=None):
    """Clone any `Model` instance.

    Model cloning is similar to calling a model on new inputs,
    except that it creates new layers (and thus new weights) instead
    of sharing the weights of the existing layers.

    # Arguments
        model: Instance of `Model`
            (could be a functional model or a Sequential model).
        input_tensors: optional list of input tensors
            to build the model upon. If not provided,
            placeholders will be created.

    # Returns
        An instance of `Model` reproducing the behavior
        of the original model, on top of new inputs tensors,
        using newly instantiated weights.

    # Raises
        ValueError: in case of invalid `model` argument value.
    """
    # `Sequential` models get their own cloning routine; anything else
    # is handed to the functional cloner, which validates `model`.
    if isinstance(model, Sequential):
        cloner = _clone_sequential_model
    else:
        cloner = _clone_functional_model
    return cloner(model, input_tensors=input_tensors)
79 changes: 79 additions & 0 deletions tests/keras/test_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from keras import backend as K
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.utils import np_utils
Expand Down Expand Up @@ -267,5 +268,83 @@ def test_rebuild_model():
assert(model.get_layer(index=-1).output_shape == (None, 32))


@keras_test
def test_clone_functional_model():
    """Clone a two-input functional model and train each clone once.

    Covers three ways of supplying inputs to `clone_model`: none
    (placeholders are created), new Keras `Input` tensors, and
    backend variables that are not Keras tensors.
    """
    data_a = np.random.random((10, 4))
    data_b = np.random.random((10, 4))
    target = np.random.random((10, 4))

    def compile_and_step(cloned, batch_inputs):
        # One training step is enough to show the clone is usable.
        cloned.compile('rmsprop', 'mse')
        cloned.train_on_batch(batch_inputs, target)

    # Reference model: `shared_dense` is applied to both inputs.
    in_a = keras.Input(shape=(4,))
    in_b = keras.Input(shape=(4,))
    shared_dense = keras.layers.Dense(4)
    second_dense = keras.layers.Dense(4)

    branch_a = shared_dense(in_a)
    dropout = keras.layers.Dropout(0.5)
    branch_a = dropout(branch_a)
    branch_b = shared_dense(in_b)
    branch_a = second_dense(branch_a)
    merged = keras.layers.add([branch_a, branch_b])
    model = keras.models.Model([in_a, in_b], merged)

    if K.backend() == 'tensorflow':
        # Everything should work in a new session.
        K.clear_session()

    # With placeholder creation
    clone = keras.models.clone_model(model)
    compile_and_step(clone, [data_a, data_b])

    # On top of new tensors
    in_a = keras.Input(shape=(4,), name='a')
    in_b = keras.Input(shape=(4,), name='b')
    clone = keras.models.clone_model(model, input_tensors=[in_a, in_b])
    compile_and_step(clone, [data_a, data_b])

    # On top of new, non-Keras tensors
    in_a = keras.backend.variable(data_a)
    in_b = keras.backend.variable(data_b)
    clone = keras.models.clone_model(model, input_tensors=[in_a, in_b])
    compile_and_step(clone, None)


@keras_test
def test_clone_sequential_model():
    """Clone a small `Sequential` model and train each clone once.

    Covers three ways of supplying the input to `clone_model`: none,
    a new Keras `Input` tensor, and a backend variable that is not
    a Keras tensor.
    """
    data = np.random.random((10, 4))
    target = np.random.random((10, 4))

    def compile_and_step(cloned, batch_inputs):
        # One training step is enough to show the clone is usable.
        cloned.compile('rmsprop', 'mse')
        cloned.train_on_batch(batch_inputs, target)

    model = keras.models.Sequential()
    for new_layer in (lambda: keras.layers.Dense(4, input_shape=(4,)),
                      lambda: keras.layers.Dropout(0.5),
                      lambda: keras.layers.Dense(4)):
        # Each layer is created right before it is added.
        model.add(new_layer())

    if K.backend() == 'tensorflow':
        # Everything should work in a new session.
        K.clear_session()

    # With placeholder creation
    clone = keras.models.clone_model(model)
    compile_and_step(clone, data)

    # On top of new tensor
    new_input = keras.Input(shape=(4,))
    clone = keras.models.clone_model(model, input_tensors=new_input)
    compile_and_step(clone, data)

    # On top of new, non-Keras tensor
    new_input = keras.backend.variable(data)
    clone = keras.models.clone_model(model, input_tensors=new_input)
    compile_and_step(clone, None)


# When this file is executed directly, run its tests through pytest.
if __name__ == '__main__':
    pytest.main([__file__])

0 comments on commit e27b8b9

Please sign in to comment.