Skip to content

Commit d12705d

Browse files
author
Mofan Zhou
committed
update theano TUT
1 parent 0354275 commit d12705d

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

theanoTUT/theano11_classification_nn/full_code.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,26 @@ def compute_accuracy(y_target, y_predict):
2929
y = T.dvector("y")
3030

3131
# initialize the weights and biases
32-
w = theano.shared(rng.randn(feats), name="w")
32+
W = theano.shared(rng.randn(feats), name="w")
3333
b = theano.shared(0., name="b")
3434

3535

3636
# Construct Theano expression graph
37-
p_1 = T.nnet.sigmoid(T.dot(x, w) + b) # Logistic Probability that target = 1 (activation function)
37+
p_1 = T.nnet.sigmoid(T.dot(x, W) + b) # Logistic Probability that target = 1 (activation function)
3838
prediction = p_1 > 0.5 # The prediction thresholded
3939
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
40-
cost = xent.mean() + 0.01 * (w ** 2).sum()# The cost to minimize (l2 regularization)
41-
gw, gb = T.grad(cost, [w, b]) # Compute the gradient of the cost
40+
# or
41+
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
42+
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
43+
gW, gb = T.grad(cost, [W, b]) # Compute the gradient of the cost
4244

4345

4446
# Compile
4547
learning_rate = 0.1
4648
train = theano.function(
4749
inputs=[x, y],
4850
outputs=[prediction, xent.mean()],
49-
updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb)))
51+
updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
5052
predict = theano.function(inputs=[x], outputs=prediction)
5153

5254
# Training

theanoTUT/theano12_cross_validation/for_you_to_practice.py theanoTUT/theano12_regularization/for_you_to_practice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
44
# Youku video tutorial: http://i.youku.com/pythontutorial
55

6-
# 12 - cross validation
6+
# 12 - regularization
77
"""
88
Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly.
99
"""

theanoTUT/theano12_cross_validation/full_code.py theanoTUT/theano12_regularization/full_code.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
44
# Youku video tutorial: http://i.youku.com/pythontutorial
55

6-
# 12 - cross validation
6+
# 12 - regularization
77
"""
88
Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly.
99
"""

0 commit comments

Comments
 (0)