# train_outline.py (forked from mivlab/AI_course)
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    """Wraps in-memory samples and labels; applies `transform` per item."""
    def __init__(self, data, label, transform=None):
        self.data = data
        self.label = label
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            # e.g. ToTensor(): HWC uint8 array -> CHW float tensor in [0, 1]
            x = self.transform(x)
        return x, self.label[index]

    def __len__(self):
        return len(self.data)
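
# Usage sketch (`imgs` and `labels` are placeholder names, not defined in
# this outline): wrap any list of HWC uint8 images and integer labels.
#   ds = MyDataset(imgs, labels, transform=transforms.ToTensor())
#   x0, y0 = ds[0]  # x0: 3x28x28 float tensor scaled to [0, 1]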

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(3, 32, 3, 1, 1),
                                   nn.ReLU(), nn.MaxPool2d(2))
        # For 3x28x28 inputs: the conv (stride 1, padding 1) keeps 28x28 and
        # MaxPool2d(2) halves it to 14x14, so the flattened size is 32*14*14.
        self.fc = nn.Sequential(nn.Linear(32 * 14 * 14, 10))

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32*14*14)
        out = self.fc(x)
        return out
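
# Hypothetical smoke test (not in the original outline): a dummy batch
# confirms the flattened size wired into the fc layer.
#   x = torch.zeros(2, 3, 28, 28)
#   assert Net()(x).shape == (2, 10)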

def train():
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    # The outline never defines `data`/`label`; random 28x28 RGB images and
    # random labels stand in here so the script runs. Swap in a real dataset.
    data = [np.random.randint(0, 256, (28, 28, 3), dtype=np.uint8)
            for _ in range(1024)]
    label = np.random.randint(0, 10, size=1024).tolist()
    train_data = MyDataset(data, label, transform=transforms.ToTensor())
    val_data = MyDataset(data, label, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=128)
    model = Net().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], 0.1)
    loss_func = nn.CrossEntropyLoss()
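    # Schedule note: with MultiStepLR([10, 20], 0.1) the learning rate is
    # 0.01 for epochs 0-9, 0.001 for epochs 10-19, and 0.0001 from epoch 20 on.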
    for epoch in range(30):
        # training -----------------------------------
        model.train()
        for batch, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            out = model(batch_x)  # (batch, 3, 28, 28) -> (batch, 10) logits
            loss = loss_func(out, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()  # update the learning rate once per epoch
        # evaluation --------------------------------
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                out = model(batch_x)
                # The outline computes `out` but no metric; top-1 accuracy
                # is one natural completion.
                correct += (out.argmax(dim=1) == batch_y).sum().item()
        print('epoch %d: val acc %.4f' % (epoch + 1, correct / len(val_data)))
        # save model --------------------------------
        torch.save(model.state_dict(), 'params_' + str(epoch + 1) + '.pth')
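        # To restore a checkpoint later (standard PyTorch pattern):
        #   model = Net(); model.load_state_dict(torch.load('params_30.pth'))
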
if __name__ == '__main__':
train()