eval.py (forked from tinyvision/DAMO-YOLO)

#!/usr/bin/env python3
# Copyright (C) Alibaba Group Holding Limited. All rights reserved.
import argparse
import os
import torch
from loguru import logger
from damo.base_models.core.ops import RepConv
from damo.apis.detector_inference import inference
from damo.config.base import parse_config
from damo.dataset import build_dataloader, build_dataset
from damo.detectors.detector import build_ddp_model, build_local_model
from damo.utils import fuse_model, get_model_info, setup_logger, synchronize
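# Usage sketch: this script assumes a distributed launch so that the
# environment variables read by init_process_group(init_method='env://')
# are set; the config and checkpoint paths below are placeholders, e.g.
#   python -m torch.distributed.launch --nproc_per_node=1 eval.py \
#       -f <config_file.py> -c <checkpoint.pth> --fuse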


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def make_parser():
    parser = argparse.ArgumentParser('damo eval')

    # distributed
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '-f',
        '--config_file',
        default=None,
        type=str,
        help='pls input your config file',
    )
    parser.add_argument('-c',
                        '--ckpt',
                        default=None,
                        type=str,
                        help='ckpt for eval')
    parser.add_argument('--conf', default=None, type=float, help='test conf')
    parser.add_argument('--nms',
                        default=None,
                        type=float,
                        help='test nms threshold')
    parser.add_argument('--tsize',
                        default=None,
                        type=int,
                        help='test img size')
    parser.add_argument('--seed', default=None, type=int, help='eval seed')
    parser.add_argument(
        '--fuse',
        dest='fuse',
        default=False,
        action='store_true',
        help='Fuse conv and bn for testing.',
    )
    parser.add_argument(
        '--test',
        dest='test',
        default=False,
        action='store_true',
        help='Evaluating on test-dev set.',
    )  # TODO
    parser.add_argument(
        'opts',
        help='Modify config options using the command-line',
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main():
    args = make_parser().parse_args()

    # distributed setup: one process per GPU, rank taken from --local_rank
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    synchronize()
    device = 'cuda'

    config = parse_config(args.config_file)
    config.merge(args.opts)

    file_name = os.path.join(config.miscs.output_dir, config.miscs.exp_name)

    if args.local_rank == 0:
        os.makedirs(file_name, exist_ok=True)

    setup_logger(file_name,
                 distributed_rank=args.local_rank,
                 filename='val_log.txt',
                 mode='a')
    logger.info('Args: {}'.format(args))

    model = build_local_model(config, device)
    model.head.nms = True

    logger.info('Model Summary: {}'.format(get_model_info(model, (640, 640))))

    model = build_ddp_model(model, local_rank=args.local_rank)
    model.cuda(args.local_rank)
    model.eval()

    ckpt_file = args.ckpt
    logger.info('loading checkpoint from {}'.format(ckpt_file))
    loc = 'cuda:{}'.format(args.local_rank)
    ckpt = torch.load(ckpt_file, map_location=loc)
    # prepend 'module.' so the checkpoint keys match the DDP-wrapped model
    new_state_dict = {}
    for k, v in ckpt['model'].items():
        k = 'module.' + k
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict, strict=False)
    logger.info('loaded checkpoint done.')

    # switch re-parameterizable RepConv blocks to their fused deploy form
    for layer in model.modules():
        if isinstance(layer, RepConv):
            layer.switch_to_deploy()

    if args.fuse:
        logger.info('\tFusing model...')
        model = fuse_model(model)

    # start evaluation
    output_folders = [None] * len(config.dataset.val_ann)
    if args.local_rank == 0 and config.miscs.output_dir:
        for idx, dataset_name in enumerate(config.dataset.val_ann):
            output_folder = os.path.join(config.miscs.output_dir, 'inference',
                                         dataset_name)
            mkdir(output_folder)
            output_folders[idx] = output_folder

    val_dataset = build_dataset(config, config.dataset.val_ann, is_train=False)
    val_loader = build_dataloader(val_dataset,
                                  config.test.augment,
                                  batch_size=config.test.batch_size,
                                  num_workers=config.miscs.num_workers,
                                  is_train=False,
                                  size_div=32)

    # run bbox evaluation on each validation dataset
    for output_folder, dataset_name, data_loader_val in zip(
            output_folders, config.dataset.val_ann, val_loader):
        inference(
            model,
            data_loader_val,
            dataset_name,
            iou_types=('bbox', ),
            box_only=False,
            device=device,
            output_folder=output_folder,
        )


if __name__ == '__main__':
    main()