-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
81 lines (67 loc) · 3.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import argparse
import json
import torch
from src.trainer import train
from src.utils import load_net, load_data, eval_accuracy
def get_args(argv=None):
    """Parse the command-line configuration for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None`` so existing callers
            (``get_args()``) keep their behavior; passing a list makes
            the parser usable programmatically and testable.

    Returns:
        argparse.Namespace with the parsed options.

    Raises:
        ValueError: if ``--load_size`` exceeds ``--batch_size``.
    """
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument('--gpuid',
                           default='0,', help='gpu id, [0] ')
    argparser.add_argument('--dataset',
                           default='fashionmnist', help='dataset, [fashionmnist] | cifar10, 1dfunction')
    argparser.add_argument('--network', default='vgg',
                           help='network, [vgg] | fnn, resnet')
    argparser.add_argument('--num_classes', type=int, default=2)
    argparser.add_argument('--n_samples_per_class', type=int,
                           default=500, help='training set size, [1000]')
    argparser.add_argument('--load_size', type=int,
                           default=1000, help='load size for dataset, [1000]')
    argparser.add_argument('--optimizer',
                           default='sgd', help='optimizer, [sgd]')
    argparser.add_argument('--n_iters', type=int,
                           default=10000, help='number of iteration used to train nets, [10000]')
    argparser.add_argument('--batch_size', type=int,
                           default=1000, help='batch size, [1000]')
    argparser.add_argument('--learning_rate', type=float,
                           default=1e-1, help='learning rate')
    # Default was the *string* '0.0'; it only worked because argparse runs
    # string defaults through `type`. Use a real float so the default does
    # not depend on that coercion rule.
    argparser.add_argument('--momentum', type=float,
                           default=0.0, help='momentum, [0.0]')
    argparser.add_argument('--model_file',
                           default='fnn.pkl', help='filename to save the net, fnn.pkl')
    args = argparser.parse_args(argv)
    if args.load_size > args.batch_size:
        raise ValueError('load size should not be larger than batch size')
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuid
    print('===> Config:')
    print(json.dumps(vars(args), indent=2))
    return args
def get_optimizer(net, args):
    """Construct the optimizer named by ``args.optimizer`` over ``net``'s parameters.

    Supports 'sgd' (uses ``args.learning_rate`` and ``args.momentum``) and
    'adam' (uses ``args.learning_rate`` only); any other name raises ValueError.
    """
    choice = args.optimizer
    if choice == 'sgd':
        return torch.optim.SGD(net.parameters(),
                               lr=args.learning_rate,
                               momentum=args.momentum)
    if choice == 'adam':
        return torch.optim.Adam(net.parameters(), lr=args.learning_rate)
    raise ValueError('optimizer %s has not been supported' % (choice))
def main():
    """Train a network according to the CLI config, report train/test
    accuracy, and save the trained weights under ``res/``."""
    args = get_args()
    # MSE loss is used for both training and evaluation below; eval_accuracy
    # presumably expects the same criterion object — confirm in src.utils.
    criterion = torch.nn.MSELoss()
    train_loader, test_loader = load_data(args.dataset,
                                          args.num_classes,
                                          train_per_class=args.n_samples_per_class,
                                          batch_size=args.load_size)
    net = load_net(args.network, args.dataset, args.num_classes)
    optimizer = get_optimizer(net, args)
    print(optimizer)
    print('===> Architecture:')
    print(net)
    print('===> Start training')
    train(net, criterion, optimizer, train_loader, args.batch_size, args.n_iters, verbose=True)
    train_loss, train_accuracy = eval_accuracy(net, criterion, train_loader)
    test_loss, test_accuracy = eval_accuracy(net, criterion, test_loader)
    print('===> Solution: ')
    print('\t train loss: %.2e, acc: %.2f' % (train_loss, train_accuracy))
    print('\t test loss: %.2e, acc: %.2f' % (test_loss, test_accuracy))
    # Fix: torch.save raises FileNotFoundError when the output directory is
    # missing, losing the freshly trained model. Create it first.
    os.makedirs('res', exist_ok=True)
    torch.save(net.state_dict(), 'res/' + args.model_file)


if __name__ == '__main__':
    main()