forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_net2net.py
389 lines (338 loc) · 16.1 KB
/
mnist_net2net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
'''This is an implementation of the Net2Net experiments on MNIST from
'Net2Net: Accelerating Learning via Knowledge Transfer'
by Tianqi Chen, Ian Goodfellow, and Jonathon Shlens
arXiv:1511.05641v4 [cs.LG] 23 Apr 2016
http://arxiv.org/abs/1511.05641
Notes
- What:
+ Net2Net is a group of methods to transfer knowledge from a teacher neural
net to a student net, so that the student net can be trained faster than
from scratch.
+ The paper discussed two specific methods of Net2Net, i.e. Net2WiderNet
and Net2DeeperNet.
+ Net2WiderNet replaces a model with an equivalent wider model that has
more units in each hidden layer.
+ Net2DeeperNet replaces a model with an equivalent deeper model.
+ Both are based on the idea of 'function-preserving transformations of
neural nets'.
- Why:
+ Enable fast exploration of multiple neural nets in experimentation and
design process, by creating a series of wider and deeper models with
transferable knowledge.
+ Enable 'lifelong learning system' by gradually adjusting model complexity
to data availability, and reusing transferable knowledge.
Experiments
- Teacher model: a basic CNN model trained on MNIST for 3 epochs.
- Net2WiderNet experiment:
+ Student model has a wider Conv2D layer and a wider FC layer.
+ Comparison of 'random-padding' vs 'net2wider' weight initialization.
+ With both methods, student model should immediately perform as well as
teacher model, but 'net2wider' is slightly better.
- Net2DeeperNet experiment:
+ Student model has an extra Conv2D layer and an extra FC layer.
+ Comparison of 'random-init' vs 'net2deeper' weight initialization.
+ Starting performance of 'net2deeper' is better than 'random-init'.
- Hyper-parameters:
+ SGD with momentum=0.9 is used for training teacher and student models.
+ Learning rate adjustment: it's suggested to reduce learning rate
to 1/10 for student model.
+ Addition of noise in 'net2wider' is used to break weight symmetry
and thus enable full capacity of student models. It is optional
when a Dropout layer is used.
Results
- Tested with 'Theano' backend and 'channels_first' image_data_format.
- Running on GPU GeForce GTX 980M
- Performance Comparisons - validation loss values during first 3 epochs:
(1) teacher_model: 0.075 0.041 0.041
(2) wider_random_pad: 0.036 0.034 0.032
(3) wider_net2wider: 0.032 0.030 0.030
(4) deeper_random_init: 0.061 0.043 0.041
(5) deeper_net2deeper: 0.032 0.031 0.029
'''
from __future__ import print_function
from six.moves import xrange
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
# Shape of one MNIST image, ordered to match the backend's configured
# image data format: (channels, rows, cols) for 'channels_first',
# (rows, cols, channels) otherwise.
input_shape = ((1, 28, 28)
               if keras.backend.image_data_format() == 'channels_first'
               else (28, 28, 1))
num_class = 10  # number of digit classes (0-9)
# load and pre-process data
def preprocess_input(x):
    '''Turn raw image arrays into model inputs.

    Reshapes `x` into a batch of images of shape `input_shape` (the batch
    size is inferred) and divides every value by 255.
    '''
    batch_shape = (-1, ) + input_shape
    images = x.reshape(batch_shape)
    return images / 255.
def preprocess_output(y):
    '''One-hot encode integer class labels into `num_class` columns.

    `num_class` is passed explicitly. Without it, `to_categorical` sizes the
    encoding from the largest label present in `y`. A label array that
    happens to lack the highest class would then produce a matrix narrower
    than the models' `num_class`-way softmax output.
    '''
    # positional second argument is the number of classes in Keras' API
    return keras.utils.to_categorical(y, num_class)
# Announce the load before doing it, so the message precedes the
# (potentially slow) download/read instead of following it.
print('Loading MNIST data...')
(train_x, train_y), (validation_x, validation_y) = mnist.load_data()
train_x, validation_x = map(preprocess_input, [train_x, validation_x])
train_y, validation_y = map(preprocess_output, [train_y, validation_y])
print('train_x shape:', train_x.shape, 'train_y shape:', train_y.shape)
print('validation_x shape:', validation_x.shape,
      'validation_y shape', validation_y.shape)
# knowledge transfer algorithms
def wider2net_conv2d(teacher_w1, teacher_b1, teacher_w2, new_width, init):
    '''Get initial weights for a wider conv2d layer with more filters,
    by 'random-padding' or 'net2wider'.

    Keras 2 stores a Conv2D kernel as (kernel_rows, kernel_cols,
    input_channels, filters), independent of `image_data_format`. The
    filter axis of a kernel is therefore axis 3, and the input-channel axis
    of the next layer's kernel is axis 2.

    # Arguments
        teacher_w1: `weight` of conv2d layer to become wider,
            of shape (kh1, kw1, num_channel1, filters1)
        teacher_b1: `bias` of conv2d layer to become wider,
            of shape (filters1, )
        teacher_w2: `weight` of next connected conv2d layer,
            of shape (kh2, kw2, num_channel2, filters2),
            where num_channel2 == filters1
        new_width: new `filters` for the wider conv2d layer
        init: initialization algorithm for new weights,
            either 'random-pad' or 'net2wider'

    # Returns
        Tuple (student_w1, student_b1, student_w2) with shapes
        (kh1, kw1, num_channel1, new_width), (new_width, ) and
        (kh2, kw2, new_width, filters2).

    # Raises
        AssertionError: if the input shapes are inconsistent or
            `new_width` does not exceed the current filter count.
        ValueError: if `init` is not a supported initializer.
    '''
    assert teacher_w1.shape[3] == teacher_w2.shape[2], (
        'successive layers from teacher model should have compatible shapes')
    assert teacher_w1.shape[3] == teacher_b1.shape[0], (
        'weight and bias from same layer should have compatible shapes')
    assert new_width > teacher_w1.shape[3], (
        'new width (filters) should be bigger than the existing one')

    n = new_width - teacher_w1.shape[3]
    if init == 'random-pad':
        new_w1 = np.random.normal(0, 0.1, size=teacher_w1.shape[:3] + (n, ))
        new_b1 = np.ones(n) * 0.1
        new_w2 = np.random.normal(
            0, 0.1, size=teacher_w2.shape[:2] + (n, teacher_w2.shape[3]))
    elif init == 'net2wider':
        # replicate randomly chosen existing filters; each replicated
        # filter's outgoing weights are split evenly among the original and
        # its copies, so the next layer's summed input is unchanged
        index = np.random.randint(teacher_w1.shape[3], size=n)
        factors = np.bincount(index)[index] + 1.
        new_w1 = teacher_w1[:, :, :, index]
        new_b1 = teacher_b1[index]
        new_w2 = teacher_w2[:, :, index, :] / factors.reshape((1, 1, -1, 1))
    else:
        raise ValueError('Unsupported weight initializer: %s' % init)

    student_w1 = np.concatenate((teacher_w1, new_w1), axis=3)
    if init == 'random-pad':
        student_w2 = np.concatenate((teacher_w2, new_w2), axis=2)
    elif init == 'net2wider':
        # add small noise to break symmetry, so that student model will have
        # full capacity later
        noise = np.random.normal(0, 5e-2 * new_w2.std(), size=new_w2.shape)
        student_w2 = np.concatenate((teacher_w2, new_w2 + noise), axis=2)
        # the original copies of the replicated filters take their share too
        student_w2[:, :, index, :] = new_w2
    student_b1 = np.concatenate((teacher_b1, new_b1), axis=0)

    return student_w1, student_b1, student_w2
def wider2net_fc(teacher_w1, teacher_b1, teacher_w2, new_width, init):
    '''Build initial weights for a widened fully connected (Dense) layer.

    The Dense layer described by `teacher_w1`/`teacher_b1` gains
    `new_width - nout1` extra output units. The input rows of the following
    layer's kernel `teacher_w2` are extended to match.

    # Arguments
        teacher_w1: kernel of the fc layer to widen, shape (nin1, nout1)
        teacher_b1: bias of the fc layer to widen, shape (nout1, )
        teacher_w2: kernel of the next connected fc layer,
            shape (nin2, nout2) with nin2 == nout1
        new_width: target `nout` for the widened layer (must exceed nout1)
        init: 'random-pad' (fresh random units) or 'net2wider'
            (copies of randomly chosen existing units)

    # Returns
        Tuple (student_w1, student_b1, student_w2).

    # Raises
        ValueError: if `init` is not a supported initializer.
    '''
    old_width = teacher_w1.shape[1]
    assert old_width == teacher_w2.shape[0], (
        'successive layers from teacher model should have compatible shapes')
    assert old_width == teacher_b1.shape[0], (
        'weight and bias from same layer should have compatible shapes')
    assert new_width > old_width, (
        'new width (nout) should be bigger than the existing one')

    extra = new_width - old_width
    if init == 'net2wider':
        # choose units to clone; a unit cloned c times hands 1/(c+1) of its
        # outgoing weights to itself and to each clone
        picks = np.random.randint(old_width, size=extra)
        shares = np.bincount(picks)[picks] + 1.
        w1_tail = teacher_w1[:, picks]
        b1_tail = teacher_b1[picks]
        w2_tail = teacher_w2[picks, :] / shares[:, np.newaxis]
        # perturb only the clones' rows so replicas can diverge in training
        jitter = np.random.normal(0, 5e-2 * w2_tail.std(), size=w2_tail.shape)
        student_w2 = np.concatenate([teacher_w2, w2_tail + jitter], axis=0)
        student_w2[picks, :] = w2_tail
    elif init == 'random-pad':
        w1_tail = np.random.normal(0, 0.1, size=(teacher_w1.shape[0], extra))
        b1_tail = np.full(extra, 0.1)
        w2_tail = np.random.normal(0, 0.1, size=(extra, teacher_w2.shape[1]))
        student_w2 = np.concatenate([teacher_w2, w2_tail], axis=0)
    else:
        raise ValueError('Unsupported weight initializer: %s' % init)

    student_w1 = np.concatenate([teacher_w1, w1_tail], axis=1)
    student_b1 = np.concatenate([teacher_b1, b1_tail], axis=0)
    return student_w1, student_b1, student_w2
def deeper2net_conv2d(teacher_w):
    '''Get initial weights for a deeper conv2d layer by 'net2deeper'.

    The new layer is an identity convolution. Output filter i reads only
    input channel i, through the centre tap of the kernel, with weight 1 and
    zero bias. With 'same' padding and an odd kernel size, the layer
    reproduces its input.

    # Arguments
        teacher_w: `weight` of previous conv2d layer, in the Keras 2 Conv2D
            kernel layout (kh, kw, num_channel, filters)

    # Returns
        Tuple (student_w, student_b): an identity kernel of shape
        (kh, kw, filters, filters) and a zero bias of shape (filters, ).
    '''
    kh, kw, _, filters = teacher_w.shape
    student_w = np.zeros((kh, kw, filters, filters))
    # integer division: `/` yields a float index under Python 3, which
    # NumPy rejects
    student_w[(kh - 1) // 2, (kw - 1) // 2] = np.eye(filters)
    student_b = np.zeros(filters)
    return student_w, student_b
def copy_weights(teacher_model, student_model, layer_names):
    '''Transfer weights, layer by layer, from teacher_model to student_model.

    For each name in `layer_names`, the teacher layer's `get_weights()`
    result is passed to the same-named student layer's `set_weights()`.
    Student layers whose names are not listed are left untouched.
    '''
    for layer_name in layer_names:
        source = teacher_model.get_layer(name=layer_name)
        target = student_model.get_layer(name=layer_name)
        target.set_weights(source.get_weights())
# methods to construct teacher_model and student_models
def make_teacher_model(train_data, validation_data, epochs=3):
    '''Build, compile and train the baseline CNN used as the teacher.

    Architecture: conv1 (64 filters, 3x3, 'same') -> pool1 -> conv2
    (64 filters, 3x3, 'same') -> pool2 -> flatten -> fc1 (64 units, ReLU)
    -> fc2 (`num_class` units, softmax). Trained with categorical
    cross-entropy and SGD (lr=0.01, momentum=0.9).

    # Arguments
        train_data: tuple (inputs, one-hot targets) to train on
        validation_data: passed unchanged to `model.fit`
        epochs: number of training epochs

    # Returns
        Tuple (model, history), where history is what `model.fit` returned.
    '''
    stack = [
        Conv2D(64, 3, input_shape=input_shape,
               padding='same', name='conv1'),
        MaxPooling2D(2, name='pool1'),
        Conv2D(64, 3, padding='same', name='conv2'),
        MaxPooling2D(2, name='pool2'),
        Flatten(name='flatten'),
        Dense(64, activation='relu', name='fc1'),
        Dense(num_class, activation='softmax', name='fc2'),
    ]
    model = Sequential(stack)
    model.compile(optimizer=SGD(lr=0.01, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    inputs, targets = train_data
    history = model.fit(inputs, targets,
                        epochs=epochs,
                        validation_data=validation_data)
    return model, history
def make_wider_student_model(teacher_model, train_data,
                             validation_data, init, epochs=3):
    '''Build and train a wider student model initialised from teacher_model.

    The student has the teacher's layer names, but conv1 is widened to 128
    filters and fc1 to 128 units. Each widened layer and the layer right
    after it are initialised from the teacher via `wider2net_conv2d` /
    `wider2net_fc` with the given `init` ('random-pad' baseline or
    'net2wider'). The model is then trained with SGD (lr=0.001,
    momentum=0.9).

    # Returns
        Tuple (model, history), where history is what `model.fit` returned.
    '''
    new_conv1_width = 128
    new_fc1_width = 128

    model = Sequential([
        # conv1 is wider than the teacher's
        Conv2D(new_conv1_width, 3, input_shape=input_shape,
               padding='same', name='conv1'),
        MaxPooling2D(2, name='pool1'),
        Conv2D(64, 3, padding='same', name='conv2'),
        MaxPooling2D(2, name='pool2'),
        Flatten(name='flatten'),
        # fc1 is wider than the teacher's
        Dense(new_fc1_width, activation='relu', name='fc1'),
        Dense(num_class, activation='softmax', name='fc2'),
    ])

    # Only the widened layers and their immediate downstream layers carry
    # weights that need initialising here; in this architecture no other
    # layer has to be copied from the teacher.
    widening_plan = (
        ('conv1', 'conv2', new_conv1_width, wider2net_conv2d),
        ('fc1', 'fc2', new_fc1_width, wider2net_fc),
    )
    for wide_name, next_name, width, widen in widening_plan:
        w_wide, b_wide = teacher_model.get_layer(wide_name).get_weights()
        w_next, b_next = teacher_model.get_layer(next_name).get_weights()
        w_wide, b_wide, w_next = widen(w_wide, b_wide, w_next, width, init)
        model.get_layer(wide_name).set_weights([w_wide, b_wide])
        # the downstream layer keeps the teacher's bias; only its kernel
        # gains extra inputs
        model.get_layer(next_name).set_weights([w_next, b_next])

    model.compile(optimizer=SGD(lr=0.001, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    inputs, targets = train_data
    history = model.fit(inputs, targets,
                        epochs=epochs,
                        validation_data=validation_data)
    return model, history
def make_deeper_student_model(teacher_model, train_data,
                              validation_data, init, epochs=3):
    '''Build and train a deeper student model initialised from teacher_model.

    The student inserts 'conv2-deeper' after conv2 and 'fc1-deeper' after
    fc1.
    - With init='net2deeper', the extra conv layer starts from the identity
      kernel built by `deeper2net_conv2d`, and the extra Dense layer uses
      Keras' 'identity' kernel initializer.
    - With init='random-init', both extra layers use Keras' default
      initialisation.

    conv1, conv2, fc1 and fc2 are copied from the teacher. The model is then
    trained with SGD (lr=0.001, momentum=0.9).

    # Raises
        ValueError: if `init` is not 'net2deeper' or 'random-init'.

    # Returns
        Tuple (model, history), where history is what `model.fit` returned.
    '''
    supported = ('net2deeper', 'random-init')

    model = Sequential()
    model.add(Conv2D(64, 3, input_shape=input_shape,
                     padding='same', name='conv1'))
    model.add(MaxPooling2D(2, name='pool1'))
    model.add(Conv2D(64, 3, padding='same', name='conv2'))

    # extra conv layer that makes the original conv2 deeper
    if init not in supported:
        raise ValueError('Unsupported weight initializer: %s' % init)
    if init == 'net2deeper':
        # conv2's freshly built kernel is read only for its shape
        conv2_kernel, _conv2_bias = model.get_layer('conv2').get_weights()
        model.add(Conv2D(64, 3, padding='same', name='conv2-deeper',
                         weights=deeper2net_conv2d(conv2_kernel)))
    else:
        model.add(Conv2D(64, 3, padding='same', name='conv2-deeper'))

    model.add(MaxPooling2D(2, name='pool2'))
    model.add(Flatten(name='flatten'))
    model.add(Dense(64, activation='relu', name='fc1'))

    # extra fc layer that makes the original fc1 deeper; for a ReLU fc layer
    # net2deeper amounts to an identity kernel initializer
    extra_fc_options = {'activation': 'relu', 'name': 'fc1-deeper'}
    if init == 'net2deeper':
        extra_fc_options['kernel_initializer'] = 'identity'
    model.add(Dense(64, **extra_fc_options))

    model.add(Dense(num_class, activation='softmax', name='fc2'))

    # every layer shared with the teacher starts from the teacher's weights
    copy_weights(teacher_model, model,
                 layer_names=['conv1', 'conv2', 'fc1', 'fc2'])

    model.compile(optimizer=SGD(lr=0.001, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    inputs, targets = train_data
    history = model.fit(inputs, targets,
                        epochs=epochs,
                        validation_data=validation_data)
    return model, history
# experiments setup
def net2wider_experiment():
    '''Benchmark the Net2WiderNet setting by training, in order:
    (1) a teacher model,
    (2) a wider student model initialised with 'random-pad',
    (3) a wider student model initialised with 'net2wider'.
    All three are trained for 3 epochs on the module-level MNIST data.
    '''
    train_data = (train_x, train_y)
    validation_data = (validation_x, validation_y)

    print('\nExperiment of Net2WiderNet ...')
    print('\nbuilding teacher model ...')
    teacher_model, _ = make_teacher_model(train_data, validation_data,
                                          epochs=3)

    students = (
        ('\nbuilding wider student model by random padding ...', 'random-pad'),
        ('\nbuilding wider student model by net2wider ...', 'net2wider'),
    )
    for banner, init in students:
        print(banner)
        make_wider_student_model(teacher_model, train_data,
                                 validation_data, init, epochs=3)
def net2deeper_experiment():
    '''Benchmark the Net2DeeperNet setting by training, in order:
    (1) a teacher model,
    (2) a deeper student model initialised with 'random-init',
    (3) a deeper student model initialised with 'net2deeper'.
    All three are trained for 3 epochs on the module-level MNIST data.
    '''
    train_data = (train_x, train_y)
    validation_data = (validation_x, validation_y)

    print('\nExperiment of Net2DeeperNet ...')
    print('\nbuilding teacher model ...')
    teacher_model, _ = make_teacher_model(train_data, validation_data,
                                          epochs=3)

    students = (
        ('\nbuilding deeper student model by random init ...', 'random-init'),
        ('\nbuilding deeper student model by net2deeper ...', 'net2deeper'),
    )
    for banner, init in students:
        print(banner)
        make_deeper_student_model(teacher_model, train_data,
                                  validation_data, init, epochs=3)
# run the experiments only when executed as a script, so importing this
# module (e.g. to reuse the weight-transfer helpers) does not start training
if __name__ == '__main__':
    net2wider_experiment()
    net2deeper_experiment()