Skip to content

Commit 05b50f6

Browse files
committed
edit
1 parent 6f5f782 commit 05b50f6

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# visit https://morvanzhou.github.io/tutorials/ for more!
2+
3+
4+
# 22 scope (name_scope/variable_scope)
5+
from __future__ import print_function
6+
import tensorflow as tf
7+
8+
class TrainConfig:
9+
batch_size = 20
10+
time_steps = 20
11+
input_size = 10
12+
output_size = 2
13+
cell_size = 11
14+
learning_rate = 0.01
15+
16+
17+
class TestConfig(TrainConfig):
18+
time_steps = 1
19+
20+
21+
class RNN(object):
22+
23+
def __init__(self, config):
24+
self._batch_size = config.batch_size
25+
self._time_steps = config.time_steps
26+
self._input_size = config.input_size
27+
self._output_size = config.output_size
28+
self._cell_size = config.cell_size
29+
self._lr = config.learning_rate
30+
self._built_RNN()
31+
32+
def _built_RNN(self):
33+
with tf.variable_scope('inputs'):
34+
self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
35+
self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
36+
with tf.name_scope('RNN'):
37+
with tf.variable_scope('input_layer'):
38+
l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D') # (batch*n_step, in_size)
39+
# Ws (in_size, cell_size)
40+
Wi = self._weight_variable([self._input_size, self._cell_size])
41+
print(Wi.name)
42+
# bs (cell_size, )
43+
bi = self._bias_variable([self._cell_size, ])
44+
# l_in_y = (batch * n_steps, cell_size)
45+
with tf.name_scope('Wx_plus_b'):
46+
l_in_y = tf.matmul(l_in_x, Wi) + bi
47+
l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')
48+
49+
with tf.variable_scope('cell'):
50+
cell = tf.nn.rnn_cell.BasicRNNCell(self._cell_size)
51+
with tf.name_scope('initial_state'):
52+
self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)
53+
54+
self.cell_outputs = []
55+
cell_state = self._cell_initial_state
56+
for t in range(self._time_steps):
57+
if t > 0: tf.get_variable_scope().reuse_variables()
58+
cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
59+
self.cell_outputs.append(cell_output)
60+
self._cell_final_state = cell_state
61+
62+
with tf.variable_scope('output_layer'):
63+
# cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
64+
cell_outputs_reshaped = tf.reshape(tf.concat(1, self.cell_outputs), [-1, self._cell_size])
65+
Wo = self._weight_variable((self._cell_size, self._output_size))
66+
bo = self._bias_variable((self._output_size,))
67+
product = tf.matmul(cell_outputs_reshaped, Wo) + bo
68+
# _pred shape (batch*time_step, output_size)
69+
self._pred = tf.nn.relu(product) # for displacement
70+
71+
with tf.name_scope('cost'):
72+
_pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
73+
mse = self.ms_error(_pred, self._ys)
74+
mse_ave_across_batch = tf.reduce_mean(mse, 0)
75+
mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
76+
self._cost = mse_sum_across_time
77+
self._cost_ave_time = self._cost / self._time_steps
78+
79+
with tf.name_scope('trian'):
80+
self._lr = tf.convert_to_tensor(self._lr)
81+
self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)
82+
83+
@staticmethod
84+
def ms_error(y_pre, y_target):
85+
return tf.square(tf.sub(y_pre, y_target))
86+
87+
@staticmethod
88+
def _weight_variable(shape, name='weights'):
89+
initializer = tf.random_normal_initializer(mean=0., stddev=0.5, )
90+
return tf.get_variable(shape=shape, initializer=initializer, name=name)
91+
92+
@staticmethod
93+
def _bias_variable(shape, name='biases'):
94+
initializer = tf.constant_initializer(0.1)
95+
return tf.get_variable(name=name, shape=shape, initializer=initializer)
96+
97+
98+
if __name__ == '__main__':
99+
train_config = TrainConfig()
100+
test_config = TestConfig()
101+
102+
# the wrong method to reuse parameters in train rnn
103+
with tf.variable_scope('train_rnn'):
104+
train_rnn1 = RNN(train_config)
105+
with tf.variable_scope('test_rnn'):
106+
test_rnn1 = RNN(test_config)
107+
108+
# the right method to reuse parameters in train rnn
109+
with tf.variable_scope('rnn') as scope:
110+
sess = tf.Session()
111+
train_rnn2 = RNN(train_config)
112+
scope.reuse_variables()
113+
test_rnn2 = RNN(test_config)
114+
# tf.initialize_all_variables() no long valid from
115+
# 2017-03-02 if using tensorflow >= 0.12
116+
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
117+
init = tf.initialize_all_variables()
118+
else:
119+
init = tf.global_variables_initializer()
120+
sess.run(init)

0 commit comments

Comments
 (0)