Commit 2d99738

add moving average
1 parent: 9ecb734

1 file changed: +23 -8

tensorflowTUT/tf23_BN/tf23_BN.py

@@ -62,7 +62,6 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
     # normalize fully connected product
     if norm:
         # Batch Normalize
-        # when testing, you should fix fc_mean, fc_var instead of using tf.nn.moments!
         fc_mean, fc_var = tf.nn.moments(
             Wx_plus_b,
             axes=[0],   # the dimension you wanna normalize, here [0] for batch
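
Note: the hunk above only drops a stale reminder comment; the per-batch statistics still come from tf.nn.moments. For readers of the patch, a minimal standalone sketch of what axes=[0] does (the input tensor is invented for illustration; TF 1.x API):

import tensorflow as tf

# batch of 3 samples with 2 features each
x = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
# axes=[0]: mean and variance per feature, taken across the batch dimension
mean, var = tf.nn.moments(x, axes=[0])
with tf.Session() as sess:
    print(sess.run([mean, var]))  # means [3., 4.], variances [2.67, 2.67]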
@@ -71,7 +70,16 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
         scale = tf.Variable(tf.ones([out_size]))
         shift = tf.Variable(tf.zeros([out_size]))
         epsilon = 0.001
-        Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)
+
+        # apply moving average for mean and var when train on batch
+        ema = tf.train.ExponentialMovingAverage(decay=0.5)
+        def mean_var_with_update():
+            ema_apply_op = ema.apply([fc_mean, fc_var])
+            with tf.control_dependencies([ema_apply_op]):
+                return tf.identity(fc_mean), tf.identity(fc_var)
+        mean, var = mean_var_with_update()
+
+        Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)
         # similar with this two steps:
         # Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
         # Wx_plus_b = Wx_plus_b * scale + shift
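
Note on the pattern just added: mean_var_with_update() is invoked unconditionally, so ema.apply runs (and the shadow averages refresh) on every sess.run. The comment deleted in the first hunk ("when testing, you should fix fc_mean, fc_var") points at the usual extension, sketched below under the assumption of a hypothetical on_train placeholder (not part of this commit): tf.cond updates the averages during training and falls back to the frozen ema.average values at test time.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 1])
on_train = tf.placeholder(tf.bool)                  # hypothetical train/test switch
fc_mean, fc_var = tf.nn.moments(x, axes=[0])
ema = tf.train.ExponentialMovingAverage(decay=0.5)

def mean_var_with_update():
    # refresh the shadow averages, then return the current batch statistics
    ema_apply_op = ema.apply([fc_mean, fc_var])
    with tf.control_dependencies([ema_apply_op]):
        return tf.identity(fc_mean), tf.identity(fc_var)

# train: update the moving averages and normalize with batch statistics;
# test: reuse the accumulated averages so normalization is deterministic
mean, var = tf.cond(on_train,
                    mean_var_with_update,
                    lambda: (ema.average(fc_mean), ema.average(fc_var)))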
@@ -95,7 +103,14 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
     scale = tf.Variable(tf.ones([1]))
     shift = tf.Variable(tf.zeros([1]))
     epsilon = 0.001
-    xs = tf.nn.batch_normalization(xs, fc_mean, fc_var, shift, scale, epsilon)
+    # apply moving average for mean and var when train on batch
+    ema = tf.train.ExponentialMovingAverage(decay=0.5)
+    def mean_var_with_update():
+        ema_apply_op = ema.apply([fc_mean, fc_var])
+        with tf.control_dependencies([ema_apply_op]):
+            return tf.identity(fc_mean), tf.identity(fc_var)
+    mean, var = mean_var_with_update()
+    xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)
 
 # record inputs for every layer
 layers_inputs = [xs]
@@ -123,7 +138,8 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 
 # make up data
 fix_seed(1)
-x_data = np.linspace(-7, 10, 500)[:, np.newaxis]
+x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]
+np.random.shuffle(x_data)
 noise = np.random.normal(0, 8, x_data.shape)
 y_data = np.square(x_data) - 5 + noise
 
@@ -147,13 +163,14 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 
 plt.ion()
 plt.figure(figsize=(7, 3))
-for i in range(251):
+for i in range(250):
     if i % 50 == 0:
         # plot histogram
         all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys: y_data})
         plot_his(all_inputs, all_inputs_norm)
 
-    sess.run([train_op, train_op_norm], feed_dict={xs: x_data, ys: y_data})
+    # train on batch
+    sess.run([train_op, train_op_norm], feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]})
 
     if i % record_step == 0:
         # record cost
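
Note: the new loop bounds line up with the enlarged dataset: 250 iterations times a slice of 10 samples walks through all 2500 shuffled points exactly once, i.e. one epoch of mini-batches. A minimal NumPy sketch of the slicing arithmetic:

import numpy as np

x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]  # shape (2500, 1)
np.random.shuffle(x_data)                          # shuffle rows so every slice is a random batch

for i in range(250):
    batch = x_data[i*10:i*10+10]                   # rows 0-9, 10-19, ..., 2490-2499
    assert batch.shape == (10, 1)                  # 250 batches * 10 samples = 2500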
@@ -167,6 +184,4 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 plt.legend()
 plt.show()
 
-# when testing, you should fix fc_mean, fc_var instead of using tf.nn.moments!
-
 