mnistm_model.py
import tensorflow as tf
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from IPython import display


# variable initialization functions
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


class Model:
    def __init__(self, x, y_):
        self.x = x  # input placeholder
        self.X = x
        self.y_ = y_  # labels placeholder (kept separate from the logits defined below)

        # scale pixel values to roughly [-0.5, 0.5]
        X_input = (tf.cast(self.X, tf.float32) / 255.) - 0.5

        # CNN model for feature extraction
        with tf.variable_scope('feature_extractor'):
            W_conv0 = weight_variable([5, 5, 3, 32])
            b_conv0 = bias_variable([32])
            h_conv0 = tf.nn.relu(conv2d(X_input, W_conv0) + b_conv0)
            h_pool0 = max_pool_2x2(h_conv0)

            W_conv1 = weight_variable([5, 5, 32, 48])
            b_conv1 = bias_variable([48])
            h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
            h_pool1 = max_pool_2x2(h_conv1)

            # two 2x2 poolings reduce 28x28 inputs to 7x7 feature maps
            self.feature = tf.reshape(h_pool1, [-1, 7 * 7 * 48])

        # simple 3-layer classifier on top of the extracted features
        W0 = weight_variable([7 * 7 * 48, 100])
        b0 = bias_variable([100])
        h0 = tf.nn.relu(tf.matmul(self.feature, W0) + b0)

        W1 = weight_variable([100, 100])
        b1 = bias_variable([100])
        h1 = tf.nn.relu(tf.matmul(h0, W1) + b1)

        W2 = weight_variable([100, 10])
        b2 = bias_variable([10])
        self.y = tf.matmul(h1, W2) + b2  # output layer (logits)

        self.var_list = [W_conv0, b_conv0, W_conv1, b_conv1, W0, b0, W1, b1, W2, b2]

        # vanilla single-task loss
        self.cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=self.y))
        self.set_vanilla_loss()

        # performance metrics
        correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    def compute_fisher(self, imgset, sess, num_samples=200, plot_diffs=False, disp_freq=10):
        # compute Fisher information for each parameter

        # initialize Fisher information for most recent task
        self.F_accum = []
        for v in range(len(self.var_list)):
            self.F_accum.append(np.zeros(self.var_list[v].get_shape().as_list()))

        # sample a random class from the softmax output
        probs = tf.nn.softmax(self.y)
        class_ind = tf.to_int32(tf.multinomial(tf.log(probs), 1)[0][0])

        if plot_diffs:
            # track differences in mean Fisher info
            F_prev = deepcopy(self.F_accum)
            mean_diffs = np.zeros(0)

        for i in range(num_samples):
            # select random input image
            im_ind = np.random.randint(imgset.shape[0])
            # compute first-order derivatives of the sampled log-likelihood
            ders = sess.run(tf.gradients(tf.log(probs[0, class_ind]), self.var_list),
                            feed_dict={self.x: imgset[im_ind:im_ind + 1]})
            # square the derivatives and add to total
            for v in range(len(self.F_accum)):
                self.F_accum[v] += np.square(ders[v])
            if plot_diffs:
                if i % disp_freq == 0 and i > 0:
                    # record mean diffs of F
                    F_diff = 0
                    for v in range(len(self.F_accum)):
                        F_diff += np.sum(np.absolute(self.F_accum[v] / (i + 1) - F_prev[v]))
                    mean_diff = np.mean(F_diff)
                    mean_diffs = np.append(mean_diffs, mean_diff)
                    for v in range(len(self.F_accum)):
                        F_prev[v] = self.F_accum[v] / (i + 1)
                    plt.plot(range(disp_freq + 1, i + 2, disp_freq), mean_diffs)
                    plt.xlabel("Number of samples")
                    plt.ylabel("Mean absolute Fisher difference")
                    display.display(plt.gcf())
                    display.clear_output(wait=True)

        # divide totals by number of samples
        for v in range(len(self.F_accum)):
            self.F_accum[v] /= num_samples

    def star(self):
        # used for saving optimal weights after most recent task training
        self.star_vars = []
        for v in range(len(self.var_list)):
            self.star_vars.append(self.var_list[v].eval())

    def restore(self, sess):
        # reassign optimal weights for latest task
        if hasattr(self, "star_vars"):
            for v in range(len(self.var_list)):
                sess.run(self.var_list[v].assign(self.star_vars[v]))

    def set_vanilla_loss(self):
        self.train_step = tf.train.GradientDescentOptimizer(0.1).minimize(self.cross_entropy)

    def update_ewc_loss(self, lam):
        # elastic weight consolidation
        # lam is weighting for previous task(s) constraints
        if not hasattr(self, "ewc_loss"):
            self.ewc_loss = self.cross_entropy
        for v in range(len(self.var_list)):
            self.ewc_loss += (lam / 2) * tf.reduce_sum(
                tf.multiply(self.F_accum[v].astype(np.float32),
                            tf.square(self.var_list[v] - self.star_vars[v])))
        self.train_step = tf.train.GradientDescentOptimizer(0.1).minimize(self.ewc_loss)
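

# ---------------------------------------------------------------------------
# Minimal usage sketch: one way the Model class above could be driven for
# two-task EWC. This is an illustrative assumption, not a prescribed recipe;
# `task_a_batches`, `task_b_batches`, and `task_a_images` are hypothetical
# batch iterators / arrays of 28x28x3 images with one-hot labels.
# Note that compute_fisher samples the label from the model's own softmax,
# so it estimates a sampled "true" Fisher rather than the empirical Fisher.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = tf.placeholder(tf.float32, shape=[None, 28, 28, 3])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])
    model = Model(x, y_)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    # task A: plain SGD with the vanilla cross-entropy loss
    # for batch_x, batch_y in task_a_batches:
    #     sess.run(model.train_step, feed_dict={x: batch_x, y_: batch_y})

    # after task A: estimate Fisher information and snapshot the weights
    # model.compute_fisher(task_a_images, sess, num_samples=200)
    # model.star()

    # task B: train on the EWC-penalized loss anchored at the task-A weights
    # model.update_ewc_loss(lam=15)  # lam is a tunable penalty strength
    # for batch_x, batch_y in task_b_batches:
    #     sess.run(model.train_step, feed_dict={x: batch_x, y_: batch_y})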