forked from yysijie/st-gcn
processor.py
191 lines (159 loc) · 7.83 KB
#!/usr/bin/env python
# pylint: disable=W0201
import sys
import argparse
import yaml
import numpy as np

# torch
import torch
import torch.nn as nn
import torch.optim as optim

# torchlight
import torchlight
from torchlight import str2bool
from torchlight import DictAction
from torchlight import import_class

from .io import IO

class Processor(IO):
    """
        Base Processor.

        Wires together argument parsing, environment setup, model and weight
        loading, data loaders and the optimizer on construction; subclasses
        override load_optimizer(), train() and test().
    """

    def __init__(self, argv=None):
        self.load_arg(argv)
        self.init_environment()
        self.load_model()
        self.load_weights()
        self.gpu()
        self.load_data()
        self.load_optimizer()

    def init_environment(self):
        super().init_environment()
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)

    def load_optimizer(self):
        pass
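
    # NOTE: load_optimizer() is intentionally left empty; it is a hook that a
    # concrete processor is expected to override with its own optimizer setup.
    # A rough, illustrative sketch only -- argument names such as `base_lr`
    # and `weight_decay` are assumptions about the subclass's parser, not
    # options defined by this base class:
    #
    #     def load_optimizer(self):
    #         self.optimizer = optim.SGD(
    #             self.model.parameters(),
    #             lr=self.arg.base_lr,
    #             momentum=0.9,
    #             weight_decay=self.arg.weight_decay)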

    def load_data(self):
        Feeder = import_class(self.arg.feeder)
        if 'debug' not in self.arg.train_feeder_args:
            self.arg.train_feeder_args['debug'] = self.arg.debug
        self.data_loader = dict()
        if self.arg.phase == 'train':
            self.data_loader['train'] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker * torchlight.ngpu(
                    self.arg.device),
                drop_last=True)
        if self.arg.test_feeder_args:
            self.data_loader['test'] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.test_feeder_args),
                batch_size=self.arg.test_batch_size,
                shuffle=False,
                num_workers=self.arg.num_worker * torchlight.ngpu(
                    self.arg.device))
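
    # NOTE: the 'test' loader is created whenever test_feeder_args is
    # non-empty, even in the training phase, so that start() can call test()
    # every eval_interval epochs; num_workers is scaled by the number of GPUs
    # reported by torchlight.ngpu(self.arg.device).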

    def show_epoch_info(self):
        for k, v in self.epoch_info.items():
            self.io.print_log('\t{}: {}'.format(k, v))

        if self.arg.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):
        if self.meta_info['iter'] % self.arg.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)
            self.io.print_log(info)

        if self.arg.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.iter_info)

    def train(self):
        for _ in range(100):
            self.iter_info['loss'] = 0
            self.show_iter_info()
            self.meta_info['iter'] += 1
        self.epoch_info['mean loss'] = 0
        self.show_epoch_info()

    def test(self):
        for _ in range(100):
            self.iter_info['loss'] = 1
            self.show_iter_info()
        self.epoch_info['mean loss'] = 1
        self.show_epoch_info()
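
    # NOTE: train() and test() above are placeholder loops that only log a
    # constant loss. Concrete processors are expected to override them with
    # real forward/backward passes over self.data_loader['train'] and
    # self.data_loader['test'], updating self.iter_info, self.epoch_info and
    # self.meta_info['iter'] in the same way so that show_iter_info() and
    # show_epoch_info() keep reporting correctly.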

    def start(self):
        self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))

        # training phase
        if self.arg.phase == 'train':
            for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
                self.meta_info['epoch'] = epoch

                # training
                self.io.print_log('Training epoch: {}'.format(epoch))
                self.train()
                self.io.print_log('Done.')

                # save model
                if ((epoch + 1) % self.arg.save_interval == 0) or (
                        epoch + 1 == self.arg.num_epoch):
                    filename = 'epoch{}_model.pt'.format(epoch + 1)
                    self.io.save_model(self.model, filename)

                # evaluation
                if ((epoch + 1) % self.arg.eval_interval == 0) or (
                        epoch + 1 == self.arg.num_epoch):
                    self.io.print_log('Eval epoch: {}'.format(epoch))
                    self.test()
                    self.io.print_log('Done.')

        # test phase
        elif self.arg.phase == 'test':

            # the path to the weights must be specified
            if self.arg.weights is None:
                raise ValueError('Please specify --weights.')
            self.io.print_log('Model: {}.'.format(self.arg.model))
            self.io.print_log('Weights: {}.'.format(self.arg.weights))

            # evaluation
            self.io.print_log('Evaluation Start:')
            self.test()
            self.io.print_log('Done.\n')

            # save the output of the model
            if self.arg.save_result:
                result_dict = dict(
                    zip(self.data_loader['test'].dataset.sample_name,
                        self.result))
                self.io.save_pkl(result_dict, 'test_result.pkl')

    @staticmethod
    def get_parser(add_help=False):

        #region arguments yapf: disable
        # parameter priority: command line > config > default
        parser = argparse.ArgumentParser(add_help=add_help, description='Base Processor')

        parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
        parser.add_argument('-c', '--config', default=None, help='path to the configuration file')

        # processor
        parser.add_argument('--phase', default='train', help='must be train or test')
        parser.add_argument('--save_result', type=str2bool, default=False, help='if true, the output of the model will be stored')
        parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch')
        parser.add_argument('--num_epoch', type=int, default=80, help='stop training at which epoch')
        parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
        parser.add_argument('--device', type=int, default=0, nargs='+', help='the indices of GPUs for training or testing')

        # visualization and debug
        parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)')
        parser.add_argument('--save_interval', type=int, default=10, help='the interval for storing models (#epoch)')
        parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#epoch)')
        parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')
        parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
        parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not')

        # feeder
        parser.add_argument('--feeder', default='feeder.feeder', help='the data feeder class to be used')
        parser.add_argument('--num_worker', type=int, default=4, help='the number of workers per gpu for the data loader')
        parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of the data loader for training')
        parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of the data loader for testing')
        parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
        parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size')
        parser.add_argument('--debug', action="store_true", help='less data, faster loading')

        # model
        parser.add_argument('--model', default=None, help='the model class to be used')
        parser.add_argument('--model_args', action=DictAction, default=dict(), help='the arguments of the model')
        parser.add_argument('--weights', default=None, help='the weights for network initialization')
        parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the names of weights to be ignored during initialization')
        #endregion yapf: enable

        return parser
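
# NOTE on configuration: the `yaml` import at the top of this file is not used
# directly here. Argument loading happens in IO.load_arg (from .io), which, in
# the upstream st-gcn layout, merges a YAML config file (--config/-c) into the
# argparse defaults before re-parsing, giving the priority noted in
# get_parser(): command line > config > default. A rough sketch of that
# pattern (an assumption about IO.load_arg, not code from this file):
#
#     parser = self.get_parser()
#     p = parser.parse_args(argv)
#     if p.config is not None:
#         with open(p.config, 'r') as f:
#             default_arg = yaml.safe_load(f)
#         parser.set_defaults(**default_arg)
#     self.arg = parser.parse_args(argv)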