-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_val.py
81 lines (61 loc) · 2.64 KB
/
train_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
import yaml
import argparse
import datetime
from lib.helpers.model_helper import build_model
from lib.helpers.dataloader_helper import build_dataloader
from lib.helpers.optimizer_helper import build_optimizer
from lib.helpers.scheduler_helper import build_lr_scheduler
from lib.helpers.trainer_helper import Trainer
from lib.helpers.tester_helper import Tester
from lib.helpers.utils_helper import create_logger
from lib.helpers.utils_helper import set_random_seed
# Command-line interface. Parsing happens at import time; the resulting
# module-level ``args`` namespace is what ``main()`` reads.
parser = argparse.ArgumentParser(
    description='End-to-End Monocular 3D Object Detection',
)
# Path of the YAML file that main() loads as the run configuration.
parser.add_argument(
    '--config',
    dest='config',
    help='settings of detection in yaml format',
)
# Flag: when set, main() skips training and only runs the tester.
parser.add_argument(
    '-e', '--evaluate_only',
    action='store_true',
    default=False,
    help='evaluation only',
)
args = parser.parse_args()
def _run_evaluation(cfg, model, test_loader, logger):
    """Build a ``Tester`` from ``cfg['tester']`` and run its ``test()`` method."""
    tester = Tester(cfg=cfg['tester'],
                    model=model,
                    dataloader=test_loader,
                    logger=logger)
    tester.test()


def main():
    """Train the model and then evaluate it, or only evaluate with ``-e``.

    Steps, in order:
      1. Load the YAML config named by ``args.config``.
      2. Call ``set_random_seed`` with ``cfg['random_seed']`` (444 if absent).
      3. Create a logger writing to ``train.log.<YYYYmmdd_HHMMSS>``.
      4. Build the train/test dataloaders and the model from the config.
      5. If ``args.evaluate_only`` is set, run the tester and return.
      6. Otherwise build the optimizer and LR schedulers, run the trainer,
         then run the tester on the trained model.

    Raises:
        FileNotFoundError: if ``--config`` was not supplied or the path does
            not exist.
    """
    config_path = args.config
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and a missing --config (None) would otherwise surface
    # as an opaque TypeError from os.path.exists.
    if not config_path or not os.path.exists(config_path):
        raise FileNotFoundError(
            'config file not found: %r (pass it with --config)' % (config_path,))

    # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
    # only trusted config files should be loaded here. Consider
    # yaml.safe_load if the configs do not rely on Python-specific tags.
    with open(config_path, 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.Loader)

    set_random_seed(cfg.get('random_seed', 444))

    # The log file name carries a timestamp so each run gets its own file.
    log_file = 'train.log.%s' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    logger = create_logger(log_file)

    # build dataloader
    train_loader, test_loader = build_dataloader(cfg['dataset'])

    # build model
    model = build_model(cfg['model'])

    if args.evaluate_only:
        logger.info('###################  Evaluation Only  ##################')
        _run_evaluation(cfg, model, test_loader, logger)
        return

    # build optimizer
    optimizer = build_optimizer(cfg['optimizer'], model)

    # build lr scheduler
    lr_scheduler, warmup_lr_scheduler = build_lr_scheduler(
        cfg['lr_scheduler'], optimizer, last_epoch=-1)

    logger.info('###################  Training  ##################')
    logger.info('Batch Size: %d' % (cfg['dataset']['batch_size']))
    logger.info('Learning Rate: %f' % (cfg['optimizer']['lr']))
    trainer = Trainer(cfg=cfg['trainer'],
                      model=model,
                      optimizer=optimizer,
                      train_loader=train_loader,
                      lr_scheduler=lr_scheduler,
                      warmup_lr_scheduler=warmup_lr_scheduler,
                      logger=logger)
    trainer.train()

    logger.info('###################  Evaluation  ##################')
    _run_evaluation(cfg, model, test_loader, logger)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()