From e27b8b9343da4558b98570d6d45599bd0e365723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Sun, 13 Aug 2017 19:03:26 -0700 Subject: [PATCH] [RELNOTES] Add `clone_model`. (#7631) * Add `clone_model`. * Restrict use of `clear_session` in test to TF backend. * Fix naming issue --- keras/engine/topology.py | 3 +- keras/models.py | 230 ++++++++++++++++++++++++++- tests/keras/test_sequential_model.py | 79 +++++++++ 3 files changed, 310 insertions(+), 2 deletions(-) diff --git a/keras/engine/topology.py b/keras/engine/topology.py index 79ee3bb0f01..90ceb3110c8 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -1790,7 +1790,8 @@ def build_map_of_graph(tensor, finished_nodes, nodes_in_progress, raise RuntimeError('The name "' + name + '" is used ' + str(all_names.count(name)) + ' times in the model. ' - 'All layer names should be unique.') + 'All layer names should be unique. ' + 'Layer names: ', all_names) # Layer parameters. # The new container starts with a single inbound node diff --git a/keras/models.py b/keras/models.py index 818c002e57f..fa3e850191e 100644 --- a/keras/models.py +++ b/keras/models.py @@ -13,6 +13,7 @@ from . import optimizers from . import layers as layer_module from .utils.io_utils import ask_to_proceed_with_overwrite +from .utils.generic_utils import has_arg from .engine.training import Model from .engine import topology from .engine.topology import Layer @@ -560,7 +561,6 @@ def build(self, input_shape=None): # Make sure child model callbacks # will call the parent Sequential model. self.model.callback_model = self - self.built = True @property @@ -1299,3 +1299,231 @@ def get_or_create_layer(layer_data): layer = get_or_create_layer(conf) model.add(layer) return model + + +def _clone_functional_model(model, input_tensors=None): + """Clone a functional `Model` instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + # Arguments + model: Instance of `Model`. + input_tensors: optional list of input tensors + to build the model upon. If not provided, + placeholders will be created. + + # Returns + An instance of `Model` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. + + # Raises + ValueError: in case of invalid `model` argument value. + """ + if not isinstance(model, Model): + raise ValueError('Expected `model` argument ' + 'to be a `Model` instance, got ', model) + if isinstance(model, Sequential): + raise ValueError('Expected `model` argument ' + 'to be a functional `Model` instance, ' + 'got a `Sequential` instance instead:', model) + + layer_map = {} # Cache for created layers. + tensor_map = {} # Map {reference_tensor: (corresponding_tensor, mask)} + if input_tensors is None: + # Create placeholders to build the model on top of. + input_layers = [] + input_tensors = [] + for layer in model.input_layers: + input_tensor = Input(batch_shape=layer.batch_input_shape, + dtype=layer.dtype, + sparse=layer.sparse, + name=layer.name) + input_tensors.append(input_tensor) + # Cache newly created input layer. + newly_created_input_layer = input_tensor._keras_history[0] + layer_map[layer] = newly_created_input_layer + for original_input_layer, cloned_input_layer in zip(model.input_layers, input_layers): + layer_map[original_input_layer] = cloned_input_layer + else: + # Make sure that all input tensors come from a Keras layer. + # If tensor comes from an input layer: cache the input layer. + input_tensors = topology._to_list(input_tensors) + _input_tensors = [] + for i, x in enumerate(input_tensors): + if not K.is_keras_tensor(x): + name = model.input_layers[i].name + input_tensor = Input(tensor=x, + name='input_wrapper_for_' + name) + _input_tensors.append(input_tensor) + # Cache newly created input layer. + original_input_layer = x._keras_history[0] + newly_created_input_layer = input_tensor._keras_history[0] + layer_map[original_input_layer] = newly_created_input_layer + else: + _input_tensors.append(x) + input_tensors = _input_tensors + + for x, y in zip(model.inputs, input_tensors): + tensor_map[x] = (y, None) # tensor, mask + + # Iterated over every node in the reference model, in depth order. + depth_keys = list(model.nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + for depth in depth_keys: + nodes = model.nodes_by_depth[depth] + for node in nodes: + # Recover the corresponding layer. + layer = node.outbound_layer + + # Get or create layer. + if layer not in layer_map: + # Clone layer. + new_layer = layer.__class__.from_config(layer.get_config()) + layer_map[layer] = new_layer + layer = new_layer + else: + # Reuse previously cloned layer. + layer = layer_map[layer] + # Don't call InputLayer multiple times. + if isinstance(layer, topology.InputLayer): + continue + + # Gather inputs to call the new layer. + reference_input_tensors = node.input_tensors + reference_output_tensors = node.output_tensors + + # If all previous input tensors are available in tensor_map, + # then call node.inbound_layer on them. + computed_data = [] # List of tuples (input, mask). + for x in reference_input_tensors: + if x in tensor_map: + computed_data.append(tensor_map[x]) + + if len(computed_data) == len(reference_input_tensors): + # Call layer. + if node.arguments: + kwargs = node.arguments + else: + kwargs = {} + if len(computed_data) == 1: + computed_tensor, computed_mask = computed_data[0] + if has_arg(layer.call, 'mask'): + if 'mask' not in kwargs: + kwargs['mask'] = computed_mask + output_tensors = topology._to_list( + layer(computed_tensor, **kwargs)) + output_masks = topology._to_list( + layer.compute_mask(computed_tensor, + computed_mask)) + computed_tensors = [computed_tensor] + computed_masks = [computed_mask] + else: + computed_tensors = [x[0] for x in computed_data] + computed_masks = [x[1] for x in computed_data] + if has_arg(layer.call, 'mask'): + if 'mask' not in kwargs: + kwargs['mask'] = computed_masks + output_tensors = topology._to_list( + layer(computed_tensors, **kwargs)) + output_masks = topology._to_list( + layer.compute_mask(computed_tensors, + computed_masks)) + # Update tensor_map. + for x, y, mask in zip(reference_output_tensors, + output_tensors, + output_masks): + tensor_map[x] = (y, mask) + + # Check that we did compute the model outputs, + # then instantiate a new model from inputs and outputs. + output_tensors = [] + for x in model.outputs: + assert x in tensor_map, 'Could not compute output ' + str(x) + tensor, _ = tensor_map[x] + output_tensors.append(tensor) + return Model(input_tensors, output_tensors, name=model.name) + + +def _clone_sequential_model(model, input_tensors=None): + """Clone a `Sequential` model instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + # Arguments + model: Instance of `Sequential`. + input_tensors: optional list of input tensors + to build the model upon. If not provided, + placeholders will be created. + + # Returns + An instance of `Sequential` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. + + # Raises + ValueError: in case of invalid `model` argument value. + """ + if not isinstance(model, Sequential): + raise ValueError('Expected `model` argument ' + 'to be a `Sequential` model instance, ' + 'but got:', model) + + def clone(layer): + return layer.__class__.from_config(layer.get_config()) + + layers = [clone(layer) for layer in model.layers] + if input_tensors is None: + return Sequential(layers=layers, name=model.name) + else: + if len(topology._to_list(input_tensors)) != 1: + raise ValueError('To clone a `Sequential` model, we expect ' + ' at most one tensor ' + 'as part of `input_tensors`.') + x = topology._to_list(input_tensors)[0] + if K.is_keras_tensor(x): + origin_layer = x._keras_history[0] + if isinstance(origin_layer, topology.InputLayer): + return Sequential(layers=[origin_layer] + layers, + name=model.name) + else: + raise ValueError('Cannot clone a `Sequential` model on top ' + 'of a tensor that comes from a Keras layer ' + 'other than an `InputLayer`. ' + 'Use the functional API instead.') + input_tensor = Input(tensor=x, + name='input_wrapper_for_' + str(x.name)) + input_layer = input_tensor._keras_history[0] + return Sequential(layers=[input_layer] + layers, name=model.name) + + +def clone_model(model, input_tensors=None): + """Clone any `Model` instance. + + Model cloning is similar to calling a model on new inputs, + except that it creates new layers (and thus new weights) instead + of sharing the weights of the existing layers. + + # Arguments + model: Instance of `Model` + (could be a functional model or a Sequential model). + input_tensors: optional list of input tensors + to build the model upon. If not provided, + placeholders will be created. + + # Returns + An instance of `Model` reproducing the behavior + of the original model, on top of new inputs tensors, + using newly instantiated weights. + + # Raises + ValueError: in case of invalid `model` argument value. + """ + if isinstance(model, Sequential): + return _clone_sequential_model(model, input_tensors=input_tensors) + else: + return _clone_functional_model(model, input_tensors=input_tensors) diff --git a/tests/keras/test_sequential_model.py b/tests/keras/test_sequential_model.py index c948f1e12c9..48ecd1d70c1 100644 --- a/tests/keras/test_sequential_model.py +++ b/tests/keras/test_sequential_model.py @@ -5,6 +5,7 @@ import numpy as np from keras import backend as K +import keras from keras.models import Sequential from keras.layers import Dense, Activation from keras.utils import np_utils @@ -267,5 +268,83 @@ def test_rebuild_model(): assert(model.get_layer(index=-1).output_shape == (None, 32)) +@keras_test +def test_clone_functional_model(): + val_a = np.random.random((10, 4)) + val_b = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + input_a = keras.Input(shape=(4,)) + input_b = keras.Input(shape=(4,)) + dense_1 = keras.layers.Dense(4,) + dense_2 = keras.layers.Dense(4,) + + x_a = dense_1(input_a) + x_a = keras.layers.Dropout(0.5)(x_a) + x_b = dense_1(input_b) + x_a = dense_2(x_a) + outputs = keras.layers.add([x_a, x_b]) + model = keras.models.Model([input_a, input_b], outputs) + + if K.backend() == 'tensorflow': + # Everything should work in a new session. + K.clear_session() + + # With placeholder creation + new_model = keras.models.clone_model(model) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch([val_a, val_b], val_out) + + # On top of new tensors + input_a = keras.Input(shape=(4,), name='a') + input_b = keras.Input(shape=(4,), name='b') + new_model = keras.models.clone_model( + model, input_tensors=[input_a, input_b]) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch([val_a, val_b], val_out) + + # On top of new, non-Keras tensors + input_a = keras.backend.variable(val_a) + input_b = keras.backend.variable(val_b) + new_model = keras.models.clone_model( + model, input_tensors=[input_a, input_b]) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch(None, val_out) + + +@keras_test +def test_clone_sequential_model(): + val_a = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(4, input_shape=(4,))) + model.add(keras.layers.Dropout(0.5)) + model.add(keras.layers.Dense(4)) + + if K.backend() == 'tensorflow': + # Everything should work in a new session. + K.clear_session() + + # With placeholder creation + new_model = keras.models.clone_model(model) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch(val_a, val_out) + + # On top of new tensor + input_a = keras.Input(shape=(4,)) + new_model = keras.models.clone_model( + model, input_tensors=input_a) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch(val_a, val_out) + + # On top of new, non-Keras tensor + input_a = keras.backend.variable(val_a) + new_model = keras.models.clone_model( + model, input_tensors=input_a) + new_model.compile('rmsprop', 'mse') + new_model.train_on_batch(None, val_out) + + if __name__ == '__main__': pytest.main([__file__])