Commit a78c430

added possibility to freeze layers during training

Mikael Rousson committed Oct 23, 2015
1 parent f1ef989 commit a78c430
Showing 3 changed files with 195 additions and 23 deletions.
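The user-facing change is small: every layer gains a `trainable` attribute, and the containers exclude the parameters of non-trainable layers when the model is compiled. A minimal sketch of the workflow this commit enables (the layer sizes and `input_shape` here are illustrative, not from the commit):

```python
from keras.models import Sequential
from keras.layers.core import Dense, Activation

feature_layers = [Dense(128, input_shape=(784,)), Activation('relu')]
classification_layers = [Dense(10), Activation('softmax')]

model = Sequential()
for l in feature_layers + classification_layers:
    model.add(l)
model.compile(loss='categorical_crossentropy', optimizer='adadelta')
# ... fit on the first task ...

# freeze the feature layers, then recompile so their weights
# are excluded from the parameter updates
for l in feature_layers:
    l.trainable = False
model.compile(loss='categorical_crossentropy', optimizer='adadelta')
# ... fit on the second task: only classification_layers are updated ...
```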
114 changes: 114 additions & 0 deletions examples/mnist_transfer_cnn.py
@@ -0,0 +1,114 @@
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import datetime

np.random.seed(1337) # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils

'''
Transfer learning toy example:
1 - Train a simple convnet on the first five digits [0..4] of the MNIST dataset.
2 - Freeze the convolutional layers and fine-tune the dense layers for the classification of digits [5..9].
Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python mnist_transfer_cnn.py
Gets to 99.8% test accuracy after 5 epochs for the first five digits classifier
and 99.2% for the last five digits after transfer + fine-tuning.
'''

now = datetime.datetime.now

batch_size = 128
nb_classes = 5
nb_epoch = 5

# input image dimensions
img_rows, img_cols = 28, 28
# number of convolutional filters to use
nb_filters = 32
# size of pooling area for max pooling
nb_pool = 2
# convolution kernel size
nb_conv = 3


def train_model(model, train, test, nb_classes):
    X_train = train[0].reshape(train[0].shape[0], 1, img_rows, img_cols)
    X_test = test[0].reshape(test[0].shape[0], 1, img_rows, img_cols)
    X_train = X_train.astype("float32")
    X_test = X_test.astype("float32")
    X_train /= 255
    X_test /= 255
    print('X_train shape:', X_train.shape)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    # convert class vectors to binary class matrices
    Y_train = np_utils.to_categorical(train[1], nb_classes)
    Y_test = np_utils.to_categorical(test[1], nb_classes)

    model.compile(loss='categorical_crossentropy', optimizer='adadelta')

    t = now()
    model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
              show_accuracy=True, verbose=1,
              validation_data=(X_test, Y_test))
    print('Training time: %s' % (now() - t))
    score = model.evaluate(X_test, Y_test, show_accuracy=True, verbose=0)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])


# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# create two datasets one with digits below 5 and one with 5 and above
X_train_lt5 = X_train[y_train < 5]
y_train_lt5 = y_train[y_train < 5]
X_test_lt5 = X_test[y_test < 5]
y_test_lt5 = y_test[y_test < 5]

X_train_gte5 = X_train[y_train >= 5]
y_train_gte5 = y_train[y_train >= 5] - 5 # make classes start at 0 for
X_test_gte5 = X_test[y_test >= 5] # np_utils.to_categorical
y_test_gte5 = y_test[y_test >= 5] - 5

# define two groups of layers: feature (convolutions) and classification (dense)
feature_layers = [
    Convolution2D(nb_filters, nb_conv, nb_conv,
                  border_mode='full',
                  input_shape=(1, img_rows, img_cols)),
    Activation('relu'),
    Convolution2D(nb_filters, nb_conv, nb_conv),
    Activation('relu'),
    MaxPooling2D(pool_size=(nb_pool, nb_pool)),
    Dropout(0.25),
    Flatten(),
]
classification_layers = [
    Dense(128),
    Activation('relu'),
    Dropout(0.5),
    Dense(nb_classes),
    Activation('softmax')
]

# create complete model
model = Sequential()
for l in feature_layers + classification_layers:
    model.add(l)

# train model for 5-digit classification [0..4]
train_model(model, (X_train_lt5, y_train_lt5), (X_test_lt5, y_test_lt5), nb_classes)

# freeze feature layers and rebuild model
for l in feature_layers:
    l.trainable = False

# transfer: train dense layers for new classification task [5..9]
train_model(model, (X_train_gte5, y_train_gte5), (X_test_gte5, y_test_gte5), nb_classes)
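A quick way to confirm the freeze takes effect (a hypothetical check, not part of the commit): `Sequential.params` is now computed from trainable layers only, so the list shrinks as soon as the feature layers are frozen.

```python
# hypothetical check, placed just before the freezing loop above
print(len(model.params))  # 8 tensors: W and b for two conv and two dense layers

for l in feature_layers:
    l.trainable = False

print(len(model.params))  # 4 tensors: only the dense layers remain trainable
```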
87 changes: 65 additions & 22 deletions keras/layers/containers.py
@@ -2,6 +2,7 @@
from __future__ import absolute_import
from __future__ import print_function

from collections import OrderedDict
import theano.tensor as T
from ..layers.core import Layer, Merge
from ..utils.theano_utils import ndim_tensor
@@ -20,11 +21,6 @@ class Sequential(Layer):

    def __init__(self, layers=[]):
        self.layers = []
        self.params = []
        self.regularizers = []
        self.constraints = []
        self.updates = []

        for layer in layers:
            self.add(layer)

@@ -38,11 +34,37 @@ def add(self, layer):
        if not hasattr(self.layers[0], 'input'):
            self.set_input()

        params, regularizers, constraints, updates = layer.get_params()
        self.params += params
        self.regularizers += regularizers
        self.constraints += constraints
        self.updates += updates

    @property
    def params(self):
        params = []
        for l in self.layers:
            if l.trainable:
                params += l.get_params()[0]
        return params

    @property
    def regularizers(self):
        regularizers = []
        for l in self.layers:
            if l.trainable:
                regularizers += l.get_params()[1]
        return regularizers

    @property
    def constraints(self):
        constraints = []
        for l in self.layers:
            if l.trainable:
                constraints += l.get_params()[2]
        return constraints

    @property
    def updates(self):
        updates = []
        for l in self.layers:
            if l.trainable:
                updates += l.get_params()[3]
        return updates

    @property
    def output_shape(self):
@@ -97,15 +119,14 @@ class Graph(Layer):
    when it has exactly one input and one output.

    inherited from Layer:
        - get_params
        - get_output_mask
        - supports_masked_input
        - get_weights
        - set_weights
    '''
    def __init__(self):
        self.namespace = set()  # strings
        self.nodes = {}  # layer-like
        self.nodes = OrderedDict()  # layer-like
        self.inputs = {}  # layer-like
        self.input_order = []  # strings
        self.outputs = {}  # layer-like
@@ -114,11 +135,6 @@ def __init__(self):
        self.output_config = []  # dicts
        self.node_config = []  # dicts

        self.params = []
        self.regularizers = []
        self.constraints = []
        self.updates = []

    @property
    def nb_input(self):
        return len(self.inputs)
@@ -127,6 +143,38 @@ def nb_input(self):
    def nb_output(self):
        return len(self.outputs)

    @property
    def params(self):
        params = []
        for l in self.nodes.values():
            if l.trainable:
                params += l.get_params()[0]
        return params

    @property
    def regularizers(self):
        regularizers = []
        for l in self.nodes.values():
            if l.trainable:
                regularizers += l.get_params()[1]
        return regularizers

    @property
    def constraints(self):
        constraints = []
        for l in self.nodes.values():
            if l.trainable:
                constraints += l.get_params()[2]
        return constraints

    @property
    def updates(self):
        updates = []
        for l in self.nodes.values():
            if l.trainable:
                updates += l.get_params()[3]
        return updates

    def set_previous(self, layer, connection_map={}):
        if self.nb_input != layer.nb_output:
            raise Exception('Cannot connect layers: input count does not match output count.')
@@ -220,11 +268,6 @@ def add_node(self, layer, name, input=None, inputs=[],
                                 'merge_mode': merge_mode,
                                 'concat_axis': concat_axis,
                                 'create_output': create_output})
        params, regularizers, constraints, updates = layer.get_params()
        self.params += params
        self.regularizers += regularizers
        self.constraints += constraints
        self.updates += updates

        if create_output:
            self.add_output(name, input=name)
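The key design change above: `params`, `regularizers`, `constraints`, and `updates` are no longer accumulated once in `__init__` and `add`/`add_node`, but recomputed from the layers on every access, so a `trainable` flag flipped at any time is respected the next time the model is compiled. `Graph.nodes` becomes an `OrderedDict` so that iterating over `nodes.values()` yields parameters in deterministic insertion order. A condensed sketch of the pattern (not the actual Keras code):

```python
class Container(object):
    def __init__(self):
        self.layers = []

    @property
    def params(self):
        # recomputed on every access, so a layer frozen after
        # construction is excluded the next time the model compiles
        return [p for layer in self.layers if layer.trainable
                for p in layer.get_params()[0]]
```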
17 changes: 16 additions & 1 deletion keras/layers/core.py
@@ -20,9 +20,11 @@
class Layer(object):
    def __init__(self, **kwargs):
        for kwarg in kwargs:
            assert kwarg in {'input_shape'}, "Keyword argument not understood: " + kwarg
            assert kwarg in {'input_shape', 'trainable'}, "Keyword argument not understood: " + kwarg
        if 'input_shape' in kwargs:
            self.set_input_shape(kwargs['input_shape'])
        if 'trainable' in kwargs:
            self._trainable = kwargs['trainable']
        if not hasattr(self, 'params'):
            self.params = []

@@ -45,6 +47,17 @@ def build(self):
        '''
        pass

    @property
    def trainable(self):
        if hasattr(self, '_trainable'):
            return self._trainable
        else:
            return True

    @trainable.setter
    def trainable(self, value):
        self._trainable = value

    @property
    def nb_input(self):
        return 1
@@ -133,6 +146,8 @@ def get_config(self):
        config = {"name": self.__class__.__name__}
        if hasattr(self, '_input_shape'):
            config['input_shape'] = self._input_shape[1:]
        if hasattr(self, '_trainable'):
            config['trainable'] = self._trainable
        return config

    def get_params(self):
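Note how the getter falls back to `True` when `_trainable` was never set: existing layers remain trainable by default, and only layers constructed with `trainable=False` or frozen afterwards are excluded. A small illustration (hypothetical snippet, not part of the commit):

```python
from keras.layers.core import Dense

layer = Dense(128)
print(layer.trainable)   # True: the default when _trainable is unset

layer.trainable = False  # freeze via the new setter...
print(layer.trainable)   # False

frozen = Dense(128, trainable=False)  # ...or via the new keyword argument
print(frozen.trainable)  # False
```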
