@@ -62,7 +62,6 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
     # normalize fully connected product
     if norm:
         # Batch Normalize
-        # when testing, you should fix fc_mean, fc_var instead of using tf.nn.moments!
         fc_mean, fc_var = tf.nn.moments(
             Wx_plus_b,
             axes=[0],   # the dimension you wanna normalize, here [0] for batch
@@ -71,7 +70,16 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
         scale = tf.Variable(tf.ones([out_size]))
         shift = tf.Variable(tf.zeros([out_size]))
         epsilon = 0.001
-        Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)
+
+        # apply a moving average to mean and var when training on batches
+        ema = tf.train.ExponentialMovingAverage(decay=0.5)
+        def mean_var_with_update():
+            ema_apply_op = ema.apply([fc_mean, fc_var])
+            with tf.control_dependencies([ema_apply_op]):
+                return tf.identity(fc_mean), tf.identity(fc_var)
+        mean, var = mean_var_with_update()
+
+        Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, mean, var, shift, scale, epsilon)
         # similar to these two steps:
         # Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
         # Wx_plus_b = Wx_plus_b * scale + shift
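The hunk above drops the old "fix fc_mean, fc_var when testing" reminder and instead tracks the batch statistics with an ExponentialMovingAverage: ema.apply creates shadow copies of fc_mean and fc_var, and the control_dependencies block guarantees the shadow update runs before the tf.identity values are read. As written, normalization still uses the current batch's statistics on every run; a common extension (a sketch only, not part of this commit, with a hypothetical on_train placeholder) switches to the frozen EMA values at test time:

    on_train = tf.placeholder(tf.bool)      # hypothetical flag, fed per sess.run

    def mean_var_with_update():
        ema_apply_op = ema.apply([fc_mean, fc_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(fc_mean), tf.identity(fc_var)

    mean, var = tf.cond(
        on_train,
        mean_var_with_update,                                 # training: batch stats + EMA update
        lambda: (ema.average(fc_mean), ema.average(fc_var)),  # testing: frozen EMA stats
    )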
@@ -95,7 +103,14 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
         scale = tf.Variable(tf.ones([1]))
         shift = tf.Variable(tf.zeros([1]))
         epsilon = 0.001
-        xs = tf.nn.batch_normalization(xs, fc_mean, fc_var, shift, scale, epsilon)
+        # apply a moving average to mean and var when training on batches
+        ema = tf.train.ExponentialMovingAverage(decay=0.5)
+        def mean_var_with_update():
+            ema_apply_op = ema.apply([fc_mean, fc_var])
+            with tf.control_dependencies([ema_apply_op]):
+                return tf.identity(fc_mean), tf.identity(fc_var)
+        mean, var = mean_var_with_update()
+        xs = tf.nn.batch_normalization(xs, mean, var, shift, scale, epsilon)
 
     # record inputs for every layer
     layers_inputs = [xs]
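This hunk repeats the same EMA pattern for the raw input xs; only the parameter width changes ([1] instead of [out_size]). If the duplication bothers you, the whole pattern could be pulled into a helper along these lines (a sketch under the same TF1 assumptions, not code from this commit):

    def batch_norm(tensor, n_out, decay=0.5, epsilon=0.001):
        # normalize `tensor` over the batch axis, updating EMA shadows as a side effect
        mean, var = tf.nn.moments(tensor, axes=[0])
        scale = tf.Variable(tf.ones([n_out]))
        shift = tf.Variable(tf.zeros([n_out]))
        ema = tf.train.ExponentialMovingAverage(decay=decay)
        with tf.control_dependencies([ema.apply([mean, var])]):
            mean, var = tf.identity(mean), tf.identity(var)
        return tf.nn.batch_normalization(tensor, mean, var, shift, scale, epsilon)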
@@ -123,7 +138,8 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 
 # make up data
 fix_seed(1)
-x_data = np.linspace(-7, 10, 500)[:, np.newaxis]
+x_data = np.linspace(-7, 10, 2500)[:, np.newaxis]
+np.random.shuffle(x_data)
 noise = np.random.normal(0, 8, x_data.shape)
 y_data = np.square(x_data) - 5 + noise
 
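Growing the dataset from 500 to 2500 points and shuffling it supports the mini-batch loop below, which takes consecutive slices of 10: without the shuffle, every batch would span a narrow stretch of the linspace and its batch statistics would be heavily biased. The shuffle happens before y_data is computed from x_data, so inputs and targets stay paired; if y_data already existed, both arrays would need a shared permutation instead, roughly:

    perm = np.random.permutation(len(x_data))    # one permutation for both arrays
    x_data, y_data = x_data[perm], y_data[perm]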
@@ -147,13 +163,14 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 
 plt.ion()
 plt.figure(figsize=(7, 3))
-for i in range(251):
+for i in range(250):
     if i % 50 == 0:
         # plot histogram
         all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys: y_data})
         plot_his(all_inputs, all_inputs_norm)
 
-    sess.run([train_op, train_op_norm], feed_dict={xs: x_data, ys: y_data})
+    # train on a mini-batch
+    sess.run([train_op, train_op_norm], feed_dict={xs: x_data[i*10:i*10+10], ys: y_data[i*10:i*10+10]})
 
     if i % record_step == 0:
         # record cost
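Switching from full-batch updates to slices of 10 is what gives the moving averages something to average: each of the 250 steps consumes the next 10 of the 2500 shuffled points, exactly one pass over the data, which is also why range(251) became range(250) (step 250 would index an empty slice). A multi-epoch run would need the slice to wrap around, e.g. (a hypothetical generalization, not in this commit):

    batch_size = 10
    start = (i * batch_size) % len(x_data)       # wrap around after each epoch
    batch_xs = x_data[start:start + batch_size]
    batch_ys = y_data[start:start + batch_size]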
@@ -167,6 +184,4 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
 plt.legend()
 plt.show()
 
-# when testing, you should fix fc_mean, fc_var instead of using tf.nn.moments!
-
 