Skip to content

Commit 4c76458

Browse files
committed
tf BN
1 parent e76180e commit 4c76458

File tree

3 files changed

+306
-122
lines changed

3 files changed

+306
-122
lines changed

tensorflowTUT/tf23_BN.py

-122
This file was deleted.

tensorflowTUT/tf23_BN/test.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
visit https://morvanzhou.github.io/tutorials/ for more!
3+
4+
Build two networks.
5+
1. Without batch normalization
6+
2. With batch normalization
7+
8+
Run tests on these two networks.
9+
"""
10+
11+
# 23 Batch Normalization
12+
13+
import numpy as np
14+
import tensorflow as tf
15+
import matplotlib.pyplot as plt
16+
17+
18+
# Hyper-parameters shared by both networks (with and without batch normalization).
ACTIVATION = tf.nn.relu  # activation applied to every hidden layer
N_LAYERS = 7  # number of hidden layers
N_HIDDEN_UNITS = 30  # units in each hidden layer
21+
22+
23+
def fix_seed(seed=1):
    """Seed TensorFlow and NumPy with ``seed`` so repeated runs are reproducible."""
    tf.set_random_seed(seed)
    np.random.seed(seed)
27+
28+
29+
def plot_his(inputs, inputs_norm):
    """Plot histograms of the inputs to every layer on a two-row grid.

    The top row shows ``inputs`` (network without normalization) and the
    bottom row shows ``inputs_norm`` (network with normalization). Each
    argument is a sequence of arrays, one per layer.
    """
    row_labels = ("Without", "With")
    for row, per_layer in enumerate((inputs, inputs_norm)):
        n_cols = len(per_layer)
        for col, values in enumerate(per_layer, start=1):
            plt.subplot(2, n_cols, row * n_cols + col)
            plt.cla()
            plt.hist(values.ravel(), bins=15, range=(-1, 1), color='#FF5733')
            plt.yticks(())
            # only the bottom row keeps its x tick labels
            plt.xticks((-1, 0, 1) if row == 1 else ())
            axes = plt.gca()
            for side in ('right', 'top'):
                axes.spines[side].set_color('none')
        # the title lands on the last subplot drawn in this row
        plt.title("%s normalizing" % row_labels[row])
    plt.draw()
    plt.pause(0.01)
48+
49+
50+
def built_net(xs, ys, norm):
    """Build a fully connected regression network and its training op.

    The network stacks ``N_LAYERS`` hidden layers of ``N_HIDDEN_UNITS`` units
    (activation ``ACTIVATION``) followed by a single linear output unit, and
    is trained with plain gradient descent on a squared-error cost.

    Args:
        xs: 2-D input tensor whose second dimension is statically known.
        ys: target tensor compared against the one-unit prediction.
        norm: if true, batch-normalize every hidden layer's pre-activation.
            The output layer is never normalized.

    Returns:
        ``[train_op, cost, layers_inputs]`` where ``layers_inputs`` is ``xs``
        followed by the output of every hidden layer.
    """
    def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
        """Return ``activation(inputs @ W + b)``, optionally batch-normalized."""
        # weights and biases (bad initialization for this case)
        Weights = tf.Variable(tf.random_normal([in_size, out_size], mean=0., stddev=1.))
        biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)

        # fully connected product
        Wx_plus_b = tf.matmul(inputs, Weights) + biases

        # normalize fully connected product
        if norm:
            # Statistics are taken over the batch axis of whatever is fed in;
            # for images one would use axes [0, 1, 2] (batch, height, width)
            # and leave the channel axis alone.
            fc_mean, fc_var = tf.nn.moments(Wx_plus_b, axes=[0])
            scale = tf.Variable(tf.ones([out_size]))
            shift = tf.Variable(tf.zeros([out_size]))
            epsilon = 0.001
            # Equivalent to:
            #   Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + epsilon)
            #   Wx_plus_b = Wx_plus_b * scale + shift
            Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)

        # activation
        if activation_function is None:
            return Wx_plus_b
        return activation_function(Wx_plus_b)

    fix_seed(1)
    # record inputs for every layer
    layers_inputs = [xs]

    # build hidden layers; each one consumes the most recently added tensor
    for _ in range(N_LAYERS):
        layer_input = layers_inputs[-1]
        in_size = layer_input.get_shape()[1].value

        output = add_layer(
            layer_input,     # input
            in_size,         # input size
            N_HIDDEN_UNITS,  # output size
            ACTIVATION,      # activation function
            norm,            # normalize before activation
        )
        layers_inputs.append(output)  # add output for next run

    # build output layer; take its input width from the tensor itself rather
    # than a literal, so it stays correct if N_HIDDEN_UNITS or N_LAYERS change
    last_layer = layers_inputs[-1]
    last_size = last_layer.get_shape()[1].value
    prediction = add_layer(last_layer, last_size, 1, activation_function=None)

    cost = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
    train_op = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
    return [train_op, cost, layers_inputs]
106+
107+
# make up data: a noisy parabola sampled on [-7, 10]
fix_seed(1)
x_data = np.linspace(-7, 10, 500)[:, np.newaxis]
noise = np.random.normal(0, 8, x_data.shape)
y_data = np.square(x_data) - 5 + noise

# To preview the input data, uncomment:
# plt.scatter(x_data, y_data)
# plt.show()

xs = tf.placeholder(tf.float32, [None, 1])  # [num_samples, num_features]
ys = tf.placeholder(tf.float32, [None, 1])

train_op, cost, layers_inputs = built_net(xs, ys, norm=False)  # without BN
train_op_norm, cost_norm, layers_inputs_norm = built_net(xs, ys, norm=True)  # with BN

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# record cost
cost_his = []
cost_his_norm = []
record_step = 5

# the whole data set is fed on every step, so build the feed dict once
full_batch = {xs: x_data, ys: y_data}

plt.ion()
plt.figure(figsize=(7, 3))
for i in range(251):
    if i % 50 == 0:
        # plot histogram of every layer's input for both networks
        all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict=full_batch)
        plot_his(all_inputs, all_inputs_norm)

    sess.run(train_op, feed_dict=full_batch)
    sess.run(train_op_norm, feed_dict=full_batch)
    if i % record_step == 0:
        # record cost
        cost_his.append(sess.run(cost, feed_dict=full_batch))
        cost_his_norm.append(sess.run(cost_norm, feed_dict=full_batch))

plt.ioff()
plt.figure()
# each curve gets an x-axis derived from its own history length
plt.plot(np.arange(len(cost_his)) * record_step, np.array(cost_his), label='no BN')  # no norm
plt.plot(np.arange(len(cost_his_norm)) * record_step, np.array(cost_his_norm), label='BN')  # norm
plt.legend()
plt.show()
152+
153+

0 commit comments

Comments
 (0)