# visit https://morvanzhou.github.io/tutorials/ for more!


# 22 scope (name_scope/variable_scope)
from __future__ import print_function
import tensorflow as tf
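
# Added illustration (not part of the original script) of the distinction this
# tutorial is about: tf.get_variable only picks up variable_scope prefixes,
# while tf.Variable and ordinary ops also pick up name_scope prefixes.
#
#     with tf.name_scope("a_name_scope"):
#         v1 = tf.get_variable(name="var1", shape=[1])   # -> 'var1:0'
#         v2 = tf.Variable([2.0], name="var2")           # -> 'a_name_scope/var2:0'
#     with tf.variable_scope("a_variable_scope"):
#         v3 = tf.get_variable(name="var3", shape=[1])   # -> 'a_variable_scope/var3:0'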

class TrainConfig:
    batch_size = 20
    time_steps = 20
    input_size = 10
    output_size = 2
    cell_size = 11
    learning_rate = 0.01


class TestConfig(TrainConfig):
    time_steps = 1


class RNN(object):

    def __init__(self, config):
        self._batch_size = config.batch_size
        self._time_steps = config.time_steps
        self._input_size = config.input_size
        self._output_size = config.output_size
        self._cell_size = config.cell_size
        self._lr = config.learning_rate
        self._built_RNN()

    def _built_RNN(self):
        with tf.variable_scope('inputs'):
            self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
            self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
        with tf.name_scope('RNN'):
            with tf.variable_scope('input_layer'):
                l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D')  # (batch*n_step, in_size)
                # Ws (in_size, cell_size)
                Wi = self._weight_variable([self._input_size, self._cell_size])
                print(Wi.name)
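                # Added note: tf.get_variable ignores tf.name_scope, so the 'RNN' name
                # scope above never shows up in Wi.name -- only the enclosing
                # variable_scopes do (e.g. 'train_rnn/input_layer/weights:0').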
                # bs (cell_size, )
                bi = self._bias_variable([self._cell_size, ])
                # l_in_y = (batch * n_steps, cell_size)
                with tf.name_scope('Wx_plus_b'):
                    l_in_y = tf.matmul(l_in_x, Wi) + bi
                l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')

            with tf.variable_scope('cell'):
                cell = tf.nn.rnn_cell.BasicRNNCell(self._cell_size)
                with tf.name_scope('initial_state'):
                    self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)

                self.cell_outputs = []
                cell_state = self._cell_initial_state
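                # The cell creates its weights on the first call below; every later
                # time step must reuse them, hence reuse_variables() for t > 0.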
                for t in range(self._time_steps):
                    if t > 0: tf.get_variable_scope().reuse_variables()
                    cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
                    self.cell_outputs.append(cell_output)
                self._cell_final_state = cell_state

            with tf.variable_scope('output_layer'):
                # cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
                cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])  # tf.concat(values, axis) assumes TF >= 1.0
                Wo = self._weight_variable((self._cell_size, self._output_size))
                bo = self._bias_variable((self._output_size,))
                product = tf.matmul(cell_outputs_reshaped, Wo) + bo
                # _pred shape (batch*time_step, output_size)
                self._pred = tf.nn.relu(product)  # for displacement

        with tf.name_scope('cost'):
            _pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
            mse = self.ms_error(_pred, self._ys)
            mse_ave_across_batch = tf.reduce_mean(mse, 0)
            mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
            self._cost = mse_sum_across_time
            self._cost_ave_time = self._cost / self._time_steps

        with tf.name_scope('train'):
            self._lr = tf.convert_to_tensor(self._lr)
            self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)

    @staticmethod
    def ms_error(y_pre, y_target):
        return tf.square(tf.subtract(y_pre, y_target))  # tf.subtract assumes TF >= 1.0 (tf.sub before that)

    @staticmethod
    def _weight_variable(shape, name='weights'):
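        # Weights come from tf.get_variable rather than tf.Variable, which is what
        # allows them to be shared when a variable_scope is reused (see __main__ below).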
        initializer = tf.random_normal_initializer(mean=0., stddev=0.5)
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    @staticmethod
    def _bias_variable(shape, name='biases'):
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)


if __name__ == '__main__':
    train_config = TrainConfig()
    test_config = TestConfig()

    # the wrong way to reuse the training RNN's parameters
    with tf.variable_scope('train_rnn'):
        train_rnn1 = RNN(train_config)
    with tf.variable_scope('test_rnn'):
        test_rnn1 = RNN(test_config)
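    # Because 'train_rnn' and 'test_rnn' are different variable scopes, tf.get_variable
    # builds a second, independent set of weights for test_rnn1, so nothing learned by
    # train_rnn1 would be visible to it.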

    # the right way to reuse the training RNN's parameters: build both models in the
    # same variable_scope and turn on reuse before building the second one
    with tf.variable_scope('rnn') as scope:
        sess = tf.Session()
        train_rnn2 = RNN(train_config)
        scope.reuse_variables()
        test_rnn2 = RNN(test_config)
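        # With reuse on, every tf.get_variable call inside test_rnn2 returns the tensors
        # already created by train_rnn2, so the two models share their weights; only the
        # op name scopes get uniquified (e.g. 'rnn/inputs_1/...').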
        # tf.initialize_all_variables() is no longer valid from
        # 2017-03-02 if using tensorflow >= 0.12
        if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
            init = tf.initialize_all_variables()
        else:
            init = tf.global_variables_initializer()
        sess.run(init)
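
        # Sanity check (added sketch, not in the original tutorial): count the trainable
        # variables under each top-level scope.  'rnn/' holds a single set of weights
        # even though two models were built there, whereas the wrong method needed one
        # set per model ('train_rnn/' plus 'test_rnn/').
        for prefix in ('train_rnn/', 'test_rnn/', 'rnn/'):
            n_vars = len([v for v in tf.trainable_variables() if v.name.startswith(prefix)])
            print(prefix, n_vars)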