Skip to content

Commit c12f7a4

Browse files
committedOct 14, 2021
feat : mmdetection/inference.py | chore : mmdtection/train.py
1 parent 9e6ff2a commit c12f7a4

File tree

2 files changed

+115
-6
lines changed

2 files changed

+115
-6
lines changed
 

‎mmdetection/inference.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
import torch
3+
import argparse
4+
import pandas as pd
5+
from mmcv import Config
6+
from mmdet.datasets import build_dataloader, build_dataset
7+
from mmdet.models import build_detector
8+
from mmdet.apis import single_gpu_test
9+
from mmcv.runner import load_checkpoint
10+
from mmcv.parallel import MMDataParallel
11+
from pycocotools.coco import COCO
12+
from GPUtil import showUtilization as gpu_usage
13+
14+
def empty_cache():
15+
"""
16+
GPU cache를 비우는 함수
17+
"""
18+
print("Initial GPU Usage")
19+
gpu_usage()
20+
print("GPU Usage after emptying the cache")
21+
torch.cuda.empty_cache()
22+
gpu_usage()
23+
24+
def save_csv(output, cfg, epoch):
25+
"""
26+
Inference 결과를 csv 파일로 저장하는 함수
27+
"""
28+
# submission 양식에 맞게 output 후처리
29+
prediction_strings = []
30+
file_names = []
31+
coco = COCO(cfg.data.test.ann_file)
32+
img_ids = coco.getImgIds()
33+
34+
class_num = 10
35+
# for i, out in enumerate(output):
36+
for i, out in zip(img_ids, output):
37+
prediction_string = ''
38+
try:
39+
image_info = coco.loadImgs(coco.getImgIds(imgIds=i))[0]
40+
except:
41+
continue
42+
for j in range(class_num):
43+
for o in out[j]:
44+
prediction_string += str(j) + ' ' + str(o[4]) + ' ' + str(o[0]) + ' ' + str(o[1]) + ' ' + str(
45+
o[2]) + ' ' + str(o[3]) + ' '
46+
47+
prediction_strings.append(prediction_string)
48+
file_names.append(image_info['file_name'])
49+
50+
submission = pd.DataFrame()
51+
submission['PredictionString'] = prediction_strings
52+
submission['image_id'] = file_names
53+
submission.to_csv(os.path.join(cfg.work_dir, f'submission_{epoch}.csv'), index=None)
54+
55+
def main(args):
56+
empty_cache()
57+
# config file 들고오기
58+
config_dir = args.config_dir
59+
config_file = args.config_file
60+
cfg = Config.fromfile(f'./configs/{config_dir}/{config_file}.py')
61+
62+
epoch = args.ckpt_name
63+
64+
cfg.data.test.test_mode = True
65+
66+
cfg.data.samples_per_gpu = args.batch_size
67+
68+
cfg.seed=args.seed
69+
cfg.gpu_ids = [1]
70+
cfg.work_dir = os.path.join('./work_dirs', config_file)
71+
72+
cfg.optimizer_config.grad_clip = dict(max_norm=35, norm_type=2)
73+
cfg.model.train_cfg = None
74+
75+
# build dataset & dataloader
76+
dataset = build_dataset(cfg.data.test)
77+
data_loader = build_dataloader(
78+
dataset,
79+
samples_per_gpu=1,
80+
workers_per_gpu=cfg.data.workers_per_gpu,
81+
dist=False,
82+
shuffle=False)
83+
84+
# checkpoint path
85+
checkpoint_path = os.path.join(cfg.work_dir, f'{epoch}.pth')
86+
87+
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) # build detector
88+
checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu') # ckpt load
89+
90+
model.CLASSES = dataset.CLASSES
91+
model = MMDataParallel(model.cuda(), device_ids=[0])
92+
93+
output = single_gpu_test(model, data_loader, show_score_thr=0.05) # output 계산
94+
95+
save_csv(output, cfg, epoch) # save csv
96+
97+
if __name__ == '__main__':
98+
parser = argparse.ArgumentParser()
99+
100+
parser.add_argument('--seed', type=int, nargs='?', default=1995, help='random seed (default: 1995)')
101+
parser.add_argument('--batch_size', type=int, nargs='?', default=2, help='batch size (default: 2)')
102+
103+
#checkpoint
104+
parser.add_argument('--ckpt_name', type=str, nargs='?', default='latest')
105+
106+
# directory, file path
107+
parser.add_argument('--data_dir', type=str, nargs='?', default='/opt/ml/detection/dataset')
108+
109+
parser.add_argument('--config_dir', type=str, nargs='?', default='swin')
110+
parser.add_argument('--config_file', type=str, nargs='?', default='cascade_rcnn_swin-t-p4-w7_fpn_ms_mosaic_1x_coco_val')
111+
112+
args = parser.parse_args()
113+
114+
# running
115+
main(args)

‎mmdetection/train.py

-6
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
from mmcv.runner import load_checkpoint
1111
from GPUtil import showUtilization as gpu_usage
1212

13-
classes = ("General trash", "Paper", "Paper pack", "Metal", "Glass",
14-
"Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")
15-
1613
def empty_cache():
1714
"""
1815
GPU cache를 비우는 함수
@@ -91,9 +88,6 @@ def main(args):
9188

9289
# directory, file path
9390
parser.add_argument('--data_dir', type=str, nargs='?', default='/opt/ml/detection/dataset')
94-
parser.add_argument('--train_file', type=str, nargs='?', default='train_split_0.json')
95-
parser.add_argument('--valid_file', type=str, nargs='?', default='valid_split_0.json')
96-
parser.add_argument('--test_file', type=str, nargs='?', default='test.json')
9791

9892
parser.add_argument('--config_dir', type=str, nargs='?', default='swin')
9993
parser.add_argument('--config_file', type=str, nargs='?', default='cascade_rcnn_swin-t-p4-w7_fpn_ms_mosaic_1x_coco_val')

0 commit comments

Comments
 (0)
Please sign in to comment.