forked from whai362/PSENet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
118 lines (94 loc) · 3.16 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def report_speed(outputs, speed_meters):
    """Record per-stage timings from ``outputs`` and print them plus FPS.

    Every entry of ``outputs`` whose key contains ``'time'`` is treated as a
    per-stage duration. It is added to a running total, recorded in
    ``speed_meters[key]``, and that meter's average is printed. The summed
    total is then recorded in ``speed_meters['total_time']``. FPS is printed
    as the reciprocal of that meter's average.

    Args:
        outputs: mapping produced by the model; keys not containing
            ``'time'`` are ignored.
        speed_meters: mapping from each timing key (plus ``'total_time'``) to
            a meter object exposing ``update(value)`` and ``avg``. A timing key
            in ``outputs`` without a matching meter raises ``KeyError``.
    """
    total_time = 0
    for key in outputs:
        if 'time' not in key:
            continue
        total_time += outputs[key]
        speed_meters[key].update(outputs[key])
        print('%s: %.4f' % (key, speed_meters[key].avg))

    speed_meters['total_time'].update(total_time)
    avg_total = speed_meters['total_time'].avg
    # A zero average (no timing keys reported, or all timings zero) would
    # otherwise make the FPS computation raise ZeroDivisionError.
    if avg_total > 0:
        print('FPS: %.1f' % (1.0 / avg_total))
    else:
        print('FPS: inf')
def test(test_loader, model, cfg):
    """Run ``model`` over ``test_loader`` and write one result per image.

    Args:
        test_loader: DataLoader whose ``dataset`` exposes ``img_paths``. The
            image name for batch ``idx`` is taken from ``img_paths[idx]``, so
            this presumably relies on ``batch_size=1`` and no shuffling.
        model: model called as ``model(**data)``. Its output is handed to
            ``ResultFormat.write_result`` and, when ``cfg.report_speed`` is
            set, to ``report_speed``.
        cfg: config providing ``data.test.type``, ``test_cfg.result_path``
            and ``report_speed``.
    """
    model.eval()
    rf = ResultFormat(cfg.data.test.type, cfg.test_cfg.result_path)

    if cfg.report_speed:
        meter_names = (
            'backbone_time',
            'neck_time',
            'det_head_time',
            'det_pse_time',
            'rec_time',
            'total_time',
        )
        # AverageMeter(500): window/size argument — semantics defined in utils.
        speed_meters = {name: AverageMeter(500) for name in meter_names}

    num_batches = len(test_loader)
    for idx, data in enumerate(test_loader):
        print('Testing %d/%d' % (idx, num_batches))
        sys.stdout.flush()

        # Move the image tensor to the GPU and pass the config to the model.
        data['imgs'] = data['imgs'].cuda()
        data['cfg'] = cfg

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**data)

        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        # Name the result after the source image file, without its extension.
        img_path = test_loader.dataset.img_paths[idx]
        image_name = osp.splitext(osp.basename(img_path))[0]
        rf.write_result(image_name, outputs)
def main(args):
    """Build the test loader and model from a config, load weights, and test.

    Args:
        args: parsed CLI namespace with ``config`` (config file path),
            ``checkpoint`` (checkpoint path or ``None``) and
            ``report_speed`` (bool).

    Raises:
        FileNotFoundError: if ``args.checkpoint`` is given but is not an
            existing file.
    """
    cfg = Config.fromfile(args.config)
    # Propagate the CLI flag into both the top-level config (read by test())
    # and the test-dataset config (passed to build_data_loader()).
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=args.report_speed
        ))
    print(json.dumps(cfg._cfg_dict, indent=4))
    sys.stdout.flush()

    # data loader
    data_loader = build_data_loader(cfg.data.test)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        # test() looks up image names via dataset.img_paths[batch_index],
        # which is only correct with one image per batch and no shuffling.
        batch_size=1,
        shuffle=False,
        num_workers=2,
    )

    # model
    model = build_model(cfg.model)
    model = model.cuda()

    if args.checkpoint is not None:
        if not os.path.isfile(args.checkpoint):
            print("No checkpoint found at '{}'".format(args.checkpoint))
            raise FileNotFoundError(
                "No checkpoint found at '{}'".format(args.checkpoint))

        print("Loading model and optimizer from checkpoint '{}'".format(args.checkpoint))
        sys.stdout.flush()
        checkpoint = torch.load(args.checkpoint)
        # Weights saved from a (Distributed)DataParallel-wrapped model carry a
        # 'module.' prefix on every parameter name. Strip it only when present,
        # so checkpoints saved from a bare model load unchanged.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['state_dict'].items()
        }
        model.load_state_dict(state_dict)

    # fuse conv and bn
    model = fuse_module(model)

    # test
    test(test_loader, model, cfg)
if __name__ == '__main__':
    # Command-line entry point; main() consumes the parsed arguments.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('config', help='config file path')
    parser.add_argument('checkpoint', nargs='?', type=str, default=None,
                        help='optional checkpoint file path')
    parser.add_argument('--report_speed', action='store_true',
                        help='print per-stage timings and FPS while testing')
    args = parser.parse_args()
    main(args)