Skip to content

Commit 04dcbf1

Browse files
committedDec 5, 2016
tf BN
1 parent a9fc3ee commit 04dcbf1

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed
 

‎tensorflowTUT/tf23_BN/tf23_BN.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ def plot_his(inputs, inputs_norm):
3232
for i, input in enumerate(all_inputs):
3333
plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))
3434
plt.cla()
35-
plt.hist(input.ravel(), bins=15, range=(-1, 1), color='#FF5733')
35+
if i == 0:
36+
the_range = (-7, 10)
37+
else:
38+
the_range = (-1, 1)
39+
plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733')
3640
plt.yticks(())
3741
if j == 1:
38-
plt.xticks((-1, 0, 1))
42+
plt.xticks(the_range)
3943
else:
4044
plt.xticks(())
4145
ax = plt.gca()
@@ -81,6 +85,18 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
8185
return outputs
8286

8387
fix_seed(1)
88+
89+
if norm:
90+
# BN for the first input
91+
fc_mean, fc_var = tf.nn.moments(
92+
xs,
93+
axes=[0],
94+
)
95+
scale = tf.Variable(tf.ones([1]))
96+
shift = tf.Variable(tf.zeros([1]))
97+
epsilon = 0.001
98+
xs = tf.nn.batch_normalization(xs, fc_mean, fc_var, shift, scale, epsilon)
99+
84100
# record inputs for every layer
85101
layers_inputs = [xs]
86102

@@ -137,8 +153,8 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
137153
all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys: y_data})
138154
plot_his(all_inputs, all_inputs_norm)
139155

140-
sess.run(train_op, feed_dict={xs: x_data, ys: y_data})
141-
sess.run(train_op_norm, feed_dict={xs: x_data, ys: y_data})
156+
sess.run([train_op, train_op_norm], feed_dict={xs: x_data, ys: y_data})
157+
142158
if i % record_step == 0:
143159
# record cost
144160
cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))

0 commit comments

Comments
 (0)