Skip to content

Commit 4f96dc2

Browse files
authoredJun 28, 2021
add solov2 enhance model (PaddlePaddle#3517)
* add solov2 enhance model
1 parent 5f9b0bc commit 4f96dc2

File tree

6 files changed

+165
-43
lines changed

6 files changed

+165
-43
lines changed
 

‎configs/solov2/README.md

+14
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ SOLOv2 (Segmenting Objects by Locations) is a fast instance segmentation framewo
2727

2828
- SOLOv2 is trained on the COCO train2017 dataset and evaluated on val2017; results are reported as `mAP(IoU=0.5:0.95)`.
2929

30+
## Enhanced model
31+
| Backbone | Input size | Lr schd | V100 FP32(FPS) | Mask AP<sup>val</sup> | Download | Configs |
32+
| :---------------------: | :-------------------: | :-----: | :------------: | :-----: | :---------: | :------------------------: |
33+
| Light-R50-VD-DCN-FPN | 512 | 3x | 38.6 | 39.0 | [model](https://paddledet.bj.bcebos.com/models/solov2_r50_enhance_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/solov2/solov2_r50_enhance_coco.yml) |
34+
35+
**Optimization methods used in the enhanced model:**
36+
- Better backbone network: ResNet50vd-DCN
37+
- Better pre-trained backbone weights obtained through knowledge distillation (SSLD)
38+
- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp)
39+
- Synchronized Batch Normalization
40+
- Multi-scale training
41+
- More data augmentation methods
42+
- DropBlock
43+
3044
## Citations
3145
```
3246
@article{wang2020solov2,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
worker_num: 2
2+
TrainReader:
3+
sample_transforms:
4+
- Decode: {}
5+
- Poly2Mask: {}
6+
- RandomDistort: {}
7+
- RandomCrop: {}
8+
- RandomResize: {interp: 1,
9+
target_size: [[352, 852], [384, 852], [416, 852], [448, 852], [480, 852], [512, 852]],
10+
keep_ratio: True}
11+
- RandomFlip: {}
12+
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
13+
- Permute: {}
14+
batch_transforms:
15+
- PadBatch: {pad_to_stride: 32}
16+
- Gt2Solov2Target: {num_grids: [40, 36, 24, 16, 12],
17+
scale_ranges: [[1, 96], [48, 192], [96, 384], [192, 768], [384, 2048]],
18+
coord_sigma: 0.2}
19+
batch_size: 2
20+
shuffle: true
21+
drop_last: true
22+
23+
24+
EvalReader:
25+
sample_transforms:
26+
- Decode: {}
27+
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
28+
- Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
29+
- Permute: {}
30+
batch_transforms:
31+
- PadBatch: {pad_to_stride: 32}
32+
batch_size: 1
33+
shuffle: false
34+
drop_last: false
35+
36+
37+
TestReader:
38+
sample_transforms:
39+
- Decode: {}
40+
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
41+
- Resize: {interp: 1, target_size: [512, 852], keep_ratio: True}
42+
- Permute: {}
43+
batch_transforms:
44+
- PadBatch: {pad_to_stride: 32}
45+
batch_size: 1
46+
shuffle: false
47+
drop_last: false
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
_BASE_: [
2+
'../datasets/coco_instance.yml',
3+
'../runtime.yml',
4+
'_base_/solov2_r50_fpn.yml',
5+
'_base_/optimizer_1x.yml',
6+
'_base_/solov2_light_reader.yml',
7+
]
8+
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
9+
weights: output/solov2_r50_enhance_coco/model_final
10+
epoch: 36
11+
use_ema: true
12+
ema_decay: 0.9998
13+
14+
ResNet:
15+
depth: 50
16+
variant: d
17+
freeze_at: 0
18+
freeze_norm: false
19+
norm_type: sync_bn
20+
return_idx: [0,1,2,3]
21+
dcn_v2_stages: [1,2,3]
22+
lr_mult_list: [0.05, 0.05, 0.1, 0.15]
23+
num_stages: 4
24+
25+
SOLOv2Head:
26+
seg_feat_channels: 256
27+
stacked_convs: 3
28+
num_grids: [40, 36, 24, 16, 12]
29+
kernel_out_channels: 128
30+
solov2_loss: SOLOv2Loss
31+
mask_nms: MaskMatrixNMS
32+
dcn_v2_stages: [2]
33+
drop_block: True
34+
35+
SOLOv2MaskHead:
36+
mid_channels: 128
37+
out_channels: 128
38+
start_level: 0
39+
end_level: 3
40+
use_dcn_in_tower: True
41+
42+
LearningRate:
43+
base_lr: 0.01
44+
schedulers:
45+
- !PiecewiseDecay
46+
gamma: 0.1
47+
milestones: [24, 33]
48+
- !LinearWarmup
49+
start_factor: 0.
50+
steps: 1000

‎ppdet/modeling/heads/solov2_head.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import paddle.nn.functional as F
2323
from paddle.nn.initializer import Normal, Constant
2424

25-
from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS
25+
from ppdet.modeling.layers import ConvNormLayer, MaskMatrixNMS, DropBlock
2626
from ppdet.core.workspace import register
2727

2828
from six.moves import zip
@@ -182,7 +182,8 @@ def __init__(self,
182182
score_threshold=0.1,
183183
mask_threshold=0.5,
184184
mask_nms=None,
185-
norm_type='gn'):
185+
norm_type='gn',
186+
drop_block=False):
186187
super(SOLOv2Head, self).__init__()
187188
self.num_classes = num_classes
188189
self.in_channels = in_channels
@@ -198,6 +199,7 @@ def __init__(self,
198199
self.score_threshold = score_threshold
199200
self.mask_threshold = mask_threshold
200201
self.norm_type = norm_type
202+
self.drop_block = drop_block
201203

202204
self.kernel_pred_convs = []
203205
self.cate_pred_convs = []
@@ -250,6 +252,10 @@ def __init__(self,
250252
bias_attr=ParamAttr(initializer=Constant(
251253
value=float(-np.log((1 - 0.01) / 0.01))))))
252254

255+
if self.drop_block:
256+
self.drop_block_fun = DropBlock(
257+
block_size=3, keep_prob=0.9, name='solo_cate.dropblock')
258+
253259
def _points_nms(self, heat, kernel_size=2):
254260
hmax = F.max_pool2d(heat, kernel_size=kernel_size, stride=1, padding=1)
255261
keep = paddle.cast((hmax[:, :, :-1, :-1] == heat), 'float32')
@@ -318,10 +324,14 @@ def _get_output_single(self, input, idx):
318324

319325
for kernel_layer in self.kernel_pred_convs:
320326
kernel_feat = F.relu(kernel_layer(kernel_feat))
327+
if self.drop_block:
328+
kernel_feat = self.drop_block_fun(kernel_feat)
321329
kernel_pred = self.solo_kernel(kernel_feat)
322330
# cate branch
323331
for cate_layer in self.cate_pred_convs:
324332
cate_feat = F.relu(cate_layer(cate_feat))
333+
if self.drop_block:
334+
cate_feat = self.drop_block_fun(cate_feat)
325335
cate_pred = self.solo_cate(cate_feat)
326336

327337
if not self.training:

‎ppdet/modeling/layers.py

+41
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,47 @@ def forward(self, inputs):
250250
return out
251251

252252

253+
class DropBlock(nn.Layer):
    """DropBlock regularization layer, see https://arxiv.org/abs/1810.12890

    In training mode, random square regions of the input feature map are
    zeroed and the surviving activations are rescaled by
    (total elements / kept elements). In evaluation mode, or when
    ``keep_prob`` is 1, the input is returned untouched.
    """

    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        Args:
            block_size (int): side length of each dropped square block
            keep_prob (float): keep probability; 1 disables dropping
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        """
        Apply DropBlock to a feature map.

        Args:
            x (Tensor): feature map laid out according to ``data_format``

        Returns:
            Tensor: ``x`` itself when not training or when ``keep_prob`` is 1,
                otherwise the masked and rescaled feature map.
        """
        if not self.training or self.keep_prob == 1:
            return x

        # Probability of a position becoming the seed of a dropped block.
        # The per-dimension factor s / (s - block_size + 1) is the correction
        # from the DropBlock paper.
        # NOTE(review): assumes every spatial dim is >= block_size; a smaller
        # feature map makes the denominator zero or negative -- confirm that
        # callers never feed such inputs.
        gamma = (1. - self.keep_prob) / (self.block_size**2)
        if self.data_format == 'NCHW':
            shape = x.shape[2:]
        else:
            shape = x.shape[1:3]
        for s in shape:
            gamma *= s / (s - self.block_size + 1)

        # 1 marks a seed position, 0 everything else.
        matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
        # Stride-1 max pooling grows every seed into a block_size-wide block.
        # NOTE(review): with an even block_size, padding of block_size // 2
        # presumably yields a mask one element larger than x per spatial
        # dim -- confirm before using even sizes.
        mask_inv = F.max_pool2d(
            matrix,
            self.block_size,
            stride=1,
            padding=self.block_size // 2,
            data_format=self.data_format)
        mask = 1. - mask_inv
        # mask holds only 0/1, so its sum is either 0 or >= 1. Clamping to 1
        # changes nothing unless every position was dropped, in which case it
        # avoids a division by zero that would turn the whole output into NaN.
        kept = paddle.clip(mask.sum(), min=1.)
        y = x * mask * (mask.numel() / kept)
        return y
292+
293+
253294
@register
254295
@serializable
255296
class AnchorGeneratorSSD(object):

‎ppdet/modeling/necks/yolo_fpn.py

+1-41
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import paddle.nn.functional as F
1818
from paddle import ParamAttr
1919
from ppdet.core.workspace import register, serializable
20+
from ppdet.modeling.layers import DropBlock
2021
from ..backbones.darknet import ConvBNLayer
2122
from ..shape_spec import ShapeSpec
2223

@@ -173,47 +174,6 @@ def forward(self, x):
173174
return y
174175

175176

176-
class DropBlock(nn.Layer):
    """DropBlock regularization layer, see https://arxiv.org/abs/1810.12890

    In training mode, random square regions of the input feature map are
    zeroed and the surviving activations are rescaled by
    (total elements / kept elements). In evaluation mode, or when
    ``keep_prob`` is 1, the input is returned untouched.
    """

    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
        """
        Args:
            block_size (int): side length of each dropped square block
            keep_prob (float): keep probability; 1 disables dropping
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        """
        Apply DropBlock to a feature map.

        Args:
            x (Tensor): feature map laid out according to ``data_format``

        Returns:
            Tensor: ``x`` itself when not training or when ``keep_prob`` is 1,
                otherwise the masked and rescaled feature map.
        """
        if not self.training or self.keep_prob == 1:
            return x

        # Probability of a position becoming the seed of a dropped block.
        # The per-dimension factor s / (s - block_size + 1) is the correction
        # from the DropBlock paper.
        # NOTE(review): assumes every spatial dim is >= block_size; a smaller
        # feature map makes the denominator zero or negative -- confirm that
        # callers never feed such inputs.
        gamma = (1. - self.keep_prob) / (self.block_size**2)
        if self.data_format == 'NCHW':
            shape = x.shape[2:]
        else:
            shape = x.shape[1:3]
        for s in shape:
            gamma *= s / (s - self.block_size + 1)

        # 1 marks a seed position, 0 everything else.
        matrix = paddle.cast(paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
        # Stride-1 max pooling grows every seed into a block_size-wide block.
        # NOTE(review): with an even block_size, padding of block_size // 2
        # presumably yields a mask one element larger than x per spatial
        # dim -- confirm before using even sizes.
        mask_inv = F.max_pool2d(
            matrix,
            self.block_size,
            stride=1,
            padding=self.block_size // 2,
            data_format=self.data_format)
        mask = 1. - mask_inv
        # mask holds only 0/1, so its sum is either 0 or >= 1. Clamping to 1
        # changes nothing unless every position was dropped, in which case it
        # avoids a division by zero that would turn the whole output into NaN.
        kept = paddle.clip(mask.sum(), min=1.)
        y = x * mask * (mask.numel() / kept)
        return y
215-
216-
217177
class CoordConv(nn.Layer):
218178
def __init__(self,
219179
ch_in,

0 commit comments

Comments
 (0)
Please sign in to comment.