Skip to content

Commit 63ea4ae

Browse files
committed
edit
1 parent 2930a16 commit 63ea4ae

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

pytorchTUT/306_optimizer.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Know more, visit 莫烦Python: https://morvanzhou.github.io/tutorials/
3+
My Youtube Channel: https://www.youtube.com/user/MorvanZhou
4+
5+
Dependencies:
6+
torch: 0.1.11
7+
"""
8+
import torch
9+
import torch.utils.data as Data
10+
import torch.nn.functional as F
11+
from torch.autograd import Variable
12+
import matplotlib.pyplot as plt
13+
14+
torch.manual_seed(1) # reproducible
15+
16+
LR = 0.01
17+
BATCH_SIZE = 32
18+
EPOCH = 12
19+
20+
# fake dataset
21+
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
22+
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))
23+
24+
# plot dataset
25+
plt.scatter(x.numpy(), y.numpy())
26+
plt.show()
27+
28+
# put dateset into torch dataset
29+
torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)
30+
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)
31+
32+
33+
# default network
34+
class Net(torch.nn.Module):
35+
def __init__(self):
36+
super(Net, self).__init__()
37+
self.hidden = torch.nn.Linear(1, 20) # hidden layer
38+
self.predict = torch.nn.Linear(20, 1) # output layer
39+
40+
def forward(self, x):
41+
x = F.relu(self.hidden(x)) # activation function for hidden layer
42+
x = self.predict(x) # linear output
43+
return x
44+
45+
# different nets
46+
net_SGD = Net()
47+
net_Momentum = Net()
48+
net_RMSprop = Net()
49+
net_Adam = Net()
50+
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]
51+
52+
# different optimizers
53+
opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
54+
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
55+
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
56+
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
57+
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]
58+
59+
loss_func = torch.nn.MSELoss()
60+
losses_his = [[], [], [], []] # record loss
61+
62+
# training
63+
for epoch in range(EPOCH):
64+
print('Epoch: ', epoch)
65+
for step, (batch_x, batch_y) in enumerate(loader): # for each training step
66+
b_x = Variable(batch_x)
67+
b_y = Variable(batch_y)
68+
69+
for net, opt, l_his in zip(nets, optimizers, losses_his):
70+
output = net(b_x) # get output for every net
71+
loss = loss_func(output, b_y) # compute loss for every net
72+
opt.zero_grad() # clear gradients for next train
73+
loss.backward() # backpropagation, compute gradients
74+
opt.step() # apply gradients
75+
l_his.append(loss.data[0]) # loss recoder
76+
77+
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
78+
for i, l_his in enumerate(losses_his):
79+
plt.plot(l_his, label=labels[i])
80+
plt.legend(loc='best')
81+
plt.xlabel('Steps')
82+
plt.ylabel('Loss')
83+
plt.ylim((0, 0.2))
84+
plt.show()

pytorchTUT/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ If you speak Chinese, you can watch my [Youtube channel](https://www.youtube.com
2626
* [An easy way](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/303_build_nn_quickly.py)
2727
* [Save and reload](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/304_save_reload.py)
2828
* [Train on batch](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/305_batch_train.py)
29+
* [Optimizers](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/306_optimizer.py)
2930
* Advanced neural network
3031
* [CNN](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/401_CNN.py)
3132
* [RNN-Classification](https://github.com/MorvanZhou/tutorials/blob/master/pytorchTUT/402_RNN_classifier.py)

0 commit comments

Comments
 (0)