Skip to content

Commit 980558b

Browse files
committed
add PPYOLOE architecture
1 parent edd057d commit 980558b

File tree

8 files changed

+119
-14
lines changed

8 files changed

+119
-14
lines changed

deploy/python/infer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
# Global dictionary
4040
SUPPORT_MODELS = {
41-
'YOLO', 'YOLOX', 'YOLOv5', 'RTMDet', 'YOLOv6', 'YOLOv7', 'YOLOv8'
41+
'YOLO', 'PPYOLOE', 'YOLOX', 'YOLOv5', 'RTMDet', 'YOLOv6', 'YOLOv7', 'YOLOv8'
4242
}
4343

4444
TUNED_TRT_DYNAMIC_MODELS = {}

deploy/serving/python/web_service.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030

3131
# Global dictionary
3232
SUPPORT_MODELS = {
33-
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
34-
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
35-
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
33+
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
34+
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
35+
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
3636
}
3737

3838
GLOBAL_VAR = {}

deploy/third_engine/demo_onnx_trt/trt_infer.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@
5151
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
5252
# Global dictionary
5353
SUPPORT_MODELS = {
54-
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
55-
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
56-
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
54+
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
55+
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
56+
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
5757
}
5858

5959

@@ -205,8 +205,8 @@ def create_trt_bindings(engine, context):
205205
"is_input": True if engine.binding_is_input(name) else False
206206
}
207207
if engine.binding_is_input(name):
208-
bindings[name]['cpu_data'] = np.random.randn(
209-
*shape).astype(np.float32)
208+
bindings[name]['cpu_data'] = np.random.randn(*shape).astype(
209+
np.float32)
210210
bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
211211
'cpu_data'].nbytes)
212212
else:

deploy/third_engine/onnx/infer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323

2424
# Global dictionary
2525
SUPPORT_MODELS = {
26-
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
27-
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
28-
'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
26+
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
27+
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
28+
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
2929
}
3030

3131
parser = argparse.ArgumentParser(description=__doc__)

ppdet/engine/export_utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# Global dictionary
3030
TRT_MIN_SUBGRAPH = {
3131
'YOLO': 10,
32+
'PPYOLOE': 10,
3233
'YOLOX': 20,
3334
'YOLOv5': 20,
3435
'RTMDet': 20,
@@ -130,7 +131,9 @@ def _dump_infer_config(config, path, image_shape, model):
130131
arch_state = True
131132
break
132133

133-
if infer_arch in ['YOLOX', 'YOLOv5', 'YOLOv6', 'YOLOv7', 'YOLOv8']:
134+
if infer_arch in [
135+
'YOLOX', 'PPYOLOE', 'YOLOv5', 'YOLOv6', 'YOLOv7', 'YOLOv8'
136+
]:
134137
infer_cfg['arch'] = infer_arch
135138
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
136139
arch_state = True

ppdet/modeling/architectures/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from . import meta_arch
1616
from . import yolo
17+
from . import ppyoloe
1718
from . import yolox
1819
from . import yolov5
1920
from . import yolov6
@@ -23,6 +24,7 @@
2324

2425
from .meta_arch import *
2526
from .yolo import *
27+
from .ppyoloe import *
2628
from .yolox import *
2729
from .yolov5 import *
2830
from .yolov6 import *
+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
from ppdet.core.workspace import register, create
20+
from .meta_arch import BaseArch
21+
22+
__all__ = ['PPYOLOE']
23+
# PP-YOLOE and PP-YOLOE+ are recommended to use this architecture
24+
# PP-YOLOE and PP-YOLOE+ can also use the same architecture of YOLOv3 in yolo.py
25+
26+
27+
@register
28+
class PPYOLOE(BaseArch):
29+
__category__ = 'architecture'
30+
__inject__ = ['post_process']
31+
32+
def __init__(self,
33+
backbone='CSPResNet',
34+
neck='CustomCSPPAN',
35+
yolo_head='PPYOLOEHead',
36+
post_process='BBoxPostProcess',
37+
for_mot=False):
38+
"""
39+
PPYOLOE network, see https://arxiv.org/abs/2203.16250
40+
41+
Args:
42+
backbone (nn.Layer): backbone instance
43+
neck (nn.Layer): neck instance
44+
yolo_head (nn.Layer): anchor_head instance
45+
post_process (object): `BBoxPostProcess` instance
46+
for_mot (bool): whether return other features for multi-object tracking
47+
models, default False in pure object detection models.
48+
"""
49+
super(PPYOLOE, self).__init__()
50+
self.backbone = backbone
51+
self.neck = neck
52+
self.yolo_head = yolo_head
53+
self.post_process = post_process
54+
self.for_mot = for_mot
55+
56+
@classmethod
57+
def from_config(cls, cfg, *args, **kwargs):
58+
# backbone
59+
backbone = create(cfg['backbone'])
60+
61+
# fpn
62+
kwargs = {'input_shape': backbone.out_shape}
63+
neck = create(cfg['neck'], **kwargs)
64+
65+
# head
66+
kwargs = {'input_shape': neck.out_shape}
67+
yolo_head = create(cfg['yolo_head'], **kwargs)
68+
69+
return {
70+
'backbone': backbone,
71+
'neck': neck,
72+
"yolo_head": yolo_head,
73+
}
74+
75+
def _forward(self):
76+
body_feats = self.backbone(self.inputs)
77+
neck_feats = self.neck(body_feats, self.for_mot)
78+
79+
if self.training:
80+
yolo_losses = self.yolo_head(neck_feats, self.inputs)
81+
return yolo_losses
82+
else:
83+
yolo_head_outs = self.yolo_head(neck_feats)
84+
if self.post_process is not None:
85+
bbox, bbox_num = self.post_process(
86+
yolo_head_outs, self.yolo_head.mask_anchors,
87+
self.inputs['im_shape'], self.inputs['scale_factor'])
88+
else:
89+
bbox, bbox_num = self.yolo_head.post_process(
90+
yolo_head_outs, self.inputs['scale_factor'])
91+
output = {'bbox': bbox, 'bbox_num': bbox_num}
92+
93+
return output
94+
95+
def get_loss(self):
96+
return self._forward()
97+
98+
def get_pred(self):
99+
return self._forward()

ppdet/modeling/architectures/yolo.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from .meta_arch import BaseArch
2121

2222
__all__ = ['YOLOv3']
23-
# PP-YOLO, PP-YOLOv2, PP-YOLOE use the same architecture as YOLOv3
23+
# YOLOv3,PP-YOLO,PP-YOLOv2,PP-YOLOE,PP-YOLOE+ use the same architecture as YOLOv3
24+
# PP-YOLOE and PP-YOLOE+ are recommended to use PPYOLOE architecture in ppyoloe.py
2425

2526

2627
@register

0 commit comments

Comments
 (0)