forked from donnyyou/torchcv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunner_selector.py
128 lines (107 loc) · 3.93 KB
/
runner_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])
import sys

from runner.cls.image_classifier import ImageClassifier
from runner.cls.image_classifier_test import ImageClassifierTest
from runner.det.faster_rcnn import FasterRCNN
from runner.det.faster_rcnn_test import FastRCNNTest
from runner.det.single_shot_detector import SingleShotDetector
from runner.det.single_shot_detector_test import SingleShotDetectorTest
from runner.det.yolov3 import YOLOv3
from runner.det.yolov3_test import YOLOv3Test
from runner.pose.pose_estimator import PoseEstimator
from runner.pose.conv_pose_machine_test import ConvPoseMachineTest
from runner.pose.open_pose_test import OpenPoseTest
from runner.seg.fcn_segmentor import FCNSegmentor
from runner.seg.fcn_segmentor_test import FCNSegmentorTest
from runner.gan.image_translator import ImageTranslator
from runner.gan.image_translator_test import ImageTranslatorTest
from runner.gan.face_gan import FaceGAN
from runner.gan.face_gan_test import FaceGANTest
from lib.tools.util.logger import Logger as Log
# Runner registries, keyed by the configured ``method`` name.
# Each task has a pair of tables: ``*_METHOD_DICT`` supplies the runner class
# instantiated when ``phase == 'train'``, and ``*_TEST_DICT`` supplies the one
# instantiated for any other phase (see ``RunnerSelector``).

# --- Pose estimation ---------------------------------------------------------
# Both pose methods share one training runner but have distinct test runners.
POSE_METHOD_DICT = {
    'open_pose': PoseEstimator,
    'conv_pose_machine': PoseEstimator,
}
POSE_TEST_DICT = {
    'open_pose': OpenPoseTest,
    'conv_pose_machine': ConvPoseMachineTest,
}

# --- Semantic segmentation ---------------------------------------------------
SEG_METHOD_DICT = {
    'fcn_segmentor': FCNSegmentor,
}
SEG_TEST_DICT = {
    'fcn_segmentor': FCNSegmentorTest,
}

# --- Object detection --------------------------------------------------------
DET_METHOD_DICT = {
    'faster_rcnn': FasterRCNN,
    'single_shot_detector': SingleShotDetector,
    'yolov3': YOLOv3,
}
DET_TEST_DICT = {
    'faster_rcnn': FastRCNNTest,
    'single_shot_detector': SingleShotDetectorTest,
    'yolov3': YOLOv3Test,
}

# --- Image classification ----------------------------------------------------
CLS_METHOD_DICT = {
    'image_classifier': ImageClassifier,
}
CLS_TEST_DICT = {
    'image_classifier': ImageClassifierTest,
}

# --- Generative models -------------------------------------------------------
GAN_METHOD_DICT = {
    'image_translator': ImageTranslator,
    'face_gan': FaceGAN,
}
GAN_TEST_DICT = {
    'image_translator': ImageTranslatorTest,
    'face_gan': FaceGANTest,
}
class RunnerSelector(object):
    """Choose and build the runner for a task from the configured method/phase.

    Every ``*_runner`` method reads ``method`` and ``phase`` from the configer,
    validates the method against that task's train/test registries, and
    returns an instance of the matching runner class built with the configer.
    """

    def __init__(self, configer):
        # Configuration source; this class only calls ``get('method')`` and
        # ``get('phase')`` on it.
        self.configer = configer

    def _select(self, method_dict, test_dict, task_name):
        """Instantiate the runner registered for the configured method.

        Args:
            method_dict: mapping of method name -> runner class used when the
                configured phase is ``'train'``.
            test_dict: mapping of method name -> runner class used for any
                other phase.
            task_name: short task label shown in the error message
                (e.g. ``'Pose'``, ``'Det'``).

        Returns:
            ``method_dict[key](configer)`` when ``phase == 'train'``,
            otherwise ``test_dict[key](configer)``.

        If the method is missing from either mapping, an error is logged and
        the process exits with status 1 (``SystemExit``).
        """
        key = self.configer.get('method')
        # The method must be registered for both phases, so an unsupported
        # method is rejected up front regardless of the current phase.
        if key not in method_dict or key not in test_dict:
            Log.error('{} Method: {} is not valid.'.format(task_name, key))
            # sys.exit rather than the site-injected exit(), which is meant
            # for the interactive shell and is missing under ``python -S``.
            sys.exit(1)

        if self.configer.get('phase') == 'train':
            return method_dict[key](self.configer)

        return test_dict[key](self.configer)

    def pose_runner(self):
        """Return the pose-estimation runner for the configured method/phase."""
        return self._select(POSE_METHOD_DICT, POSE_TEST_DICT, 'Pose')

    def det_runner(self):
        """Return the detection runner for the configured method/phase."""
        return self._select(DET_METHOD_DICT, DET_TEST_DICT, 'Det')

    def seg_runner(self):
        """Return the segmentation runner for the configured method/phase."""
        # Previously mislabelled its error as 'Det Method'.
        return self._select(SEG_METHOD_DICT, SEG_TEST_DICT, 'Seg')

    def cls_runner(self):
        """Return the classification runner for the configured method/phase."""
        return self._select(CLS_METHOD_DICT, CLS_TEST_DICT, 'Cls')

    def gan_runner(self):
        """Return the GAN runner for the configured method/phase."""
        # Previously mislabelled its error as 'Cls Method'.
        return self._select(GAN_METHOD_DICT, GAN_TEST_DICT, 'Gan')