Skip to content

Commit

Permalink
support soft-nms and potential new nms methods
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock committed Nov 26, 2018
1 parent af755e0 commit dd2907e
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 93 deletions.
5 changes: 4 additions & 1 deletion configs/cascade_mask_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5),
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5),
keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
Expand Down
3 changes: 2 additions & 1 deletion configs/cascade_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5),
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100),
keep_all_stages=False)
# dataset settings
dataset_type = 'CocoDataset'
Expand Down
5 changes: 4 additions & 1 deletion configs/fast_mask_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@
debug=False))
test_cfg = dict(
rcnn=dict(
score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
Expand Down
4 changes: 3 additions & 1 deletion configs/fast_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
neg_balance_thr=0),
pos_weight=-1,
debug=False))
test_cfg = dict(rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
test_cfg = dict(
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
Expand Down
6 changes: 5 additions & 1 deletion configs/faster_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@
max_num=2000,
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(score_thr=0.05, max_per_img=100, nms_thr=0.5))
rcnn=dict(
score_thr=0.05, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
# soft-nms is also supported for rcnn testing
# e.g., nms=dict(type='soft_nms', iou_thr=0.5, min_score=0.05)
)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
Expand Down
5 changes: 4 additions & 1 deletion configs/mask_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@
nms_thr=0.7,
min_bbox_size=0),
rcnn=dict(
score_thr=0.05, max_per_img=100, nms_thr=0.5, mask_thr_binary=0.5))
score_thr=0.05,
nms=dict(type='nms', iou_thr=0.5),
max_per_img=100,
mask_thr_binary=0.5))
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
Expand Down
13 changes: 7 additions & 6 deletions mmdet/core/post_processing/bbox_nms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch

from mmdet.ops import nms
from mmdet.ops.nms import nms_wrapper


def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1):
"""NMS for multi-class bboxes.
Args:
Expand All @@ -21,6 +21,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
"""
num_classes = multi_scores.shape[1]
bboxes, labels = [], []
nms_cfg_ = nms_cfg.copy()
nms_type = nms_cfg_.pop('type', 'nms')
nms_op = getattr(nms_wrapper, nms_type)
for i in range(1, num_classes):
cls_inds = multi_scores[:, i] > score_thr
if not cls_inds.any():
Expand All @@ -32,11 +35,9 @@ def multiclass_nms(multi_bboxes, multi_scores, score_thr, nms_thr, max_num=-1):
_bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4]
_scores = multi_scores[cls_inds, i]
cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1)
# perform nms
nms_keep = nms(cls_dets, nms_thr)
cls_dets = cls_dets[nms_keep, :]
cls_dets, _ = nms_op(cls_dets, **nms_cfg_)
cls_labels = multi_bboxes.new_full(
(len(nms_keep), ), i - 1, dtype=torch.long)
(cls_dets.shape[0], ), i - 1, dtype=torch.long)
bboxes.append(cls_dets)
labels.append(cls_labels)
if bboxes:
Expand Down
4 changes: 1 addition & 3 deletions mmdet/core/post_processing/merge_augs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg):
scale_factor, flip)
recovered_proposals.append(_proposals)
aug_proposals = torch.cat(recovered_proposals, dim=0)
nms_keep = nms(aug_proposals, rpn_test_cfg.nms_thr,
aug_proposals.get_device())
merged_proposals = aug_proposals[nms_keep, :]
merged_proposals, _ = nms(aug_proposals, rpn_test_cfg.nms_thr)
scores = merged_proposals[:, 4]
_, order = scores.sort(0, descending=True)
num = min(rpn_test_cfg.max_num, merged_proposals.shape[0])
Expand Down
7 changes: 3 additions & 4 deletions mmdet/models/bbox_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def get_det_bboxes(self,
img_shape,
scale_factor,
rescale=False,
nms_cfg=None):
cfg=None):
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
Expand All @@ -115,12 +115,11 @@ def get_det_bboxes(self,
if rescale:
bboxes /= scale_factor

if nms_cfg is None:
if cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(
bboxes, scores, nms_cfg.score_thr, nms_cfg.nms_thr,
nms_cfg.max_per_img)
bboxes, scores, cfg.score_thr, cfg.nms, cfg.max_per_img)

return det_bboxes, det_labels

Expand Down
4 changes: 2 additions & 2 deletions mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):
img_shape,
scale_factor,
rescale=rescale,
nms_cfg=rcnn_test_cfg)
cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
bbox_head.num_classes)
ms_bbox_result['stage{}'.format(i)] = bbox_result
Expand Down Expand Up @@ -256,7 +256,7 @@ def simple_test(self, img, img_meta, proposals=None, rescale=False):
img_shape,
scale_factor,
rescale=rescale,
nms_cfg=rcnn_test_cfg)
cfg=rcnn_test_cfg)
bbox_result = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)
ms_bbox_result['ensemble'] = bbox_result
Expand Down
10 changes: 5 additions & 5 deletions mmdet/models/detectors/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def simple_test_bboxes(self,
img_shape,
scale_factor,
rescale=rescale,
nms_cfg=rcnn_test_cfg)
cfg=rcnn_test_cfg)
return det_bboxes, det_labels

def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
Expand All @@ -73,15 +73,15 @@ def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
img_shape,
scale_factor,
rescale=False,
nms_cfg=None)
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, self.test_cfg.rcnn)
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(
merged_bboxes, merged_scores, self.test_cfg.rcnn.score_thr,
self.test_cfg.rcnn.nms_thr, self.test_cfg.rcnn.max_per_img)
merged_bboxes, merged_scores, rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms, rcnn_test_cfg.max_per_img)
return det_bboxes, det_labels


Expand Down
8 changes: 4 additions & 4 deletions mmdet/models/rpn_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ def _get_proposals_single(self, rpn_cls_scores, rpn_bbox_preds,
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
nms_keep = nms(proposals, cfg.nms_thr)[:cfg.nms_post]
proposals = proposals[nms_keep, :]
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
nms_keep = nms(proposals, cfg.nms_thr)[:cfg.max_num]
proposals = proposals[nms_keep, :]
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
_, order = scores.sort(0, descending=True)
Expand Down
80 changes: 41 additions & 39 deletions mmdet/ops/nms/cpu_soft_nms.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) University of Maryland, College Park
# Licensed under The MIT License [see LICENSE for details]
# Written by Navaneeth Bodla and Bharat Singh
# Modified by Kai Chen
# ----------------------------------------------------------

import numpy as np
Expand All @@ -15,12 +16,13 @@ cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
return a if a <= b else b


def cpu_soft_nms(
np.ndarray[float, ndim=2] boxes_in,
float iou_thr,
unsigned int method=1,
float sigma=0.5,
float Nt=0.3,
float threshold=0.001,
unsigned int method=0
float min_score=0.001,
):
boxes = boxes_in.copy()
cdef unsigned int N = boxes.shape[0]
Expand All @@ -36,11 +38,11 @@ def cpu_soft_nms(
maxscore = boxes[i, 4]
maxpos = i

tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
tx1 = boxes[i, 0]
ty1 = boxes[i, 1]
tx2 = boxes[i, 2]
ty2 = boxes[i, 3]
ts = boxes[i, 4]
ti = inds[i]

pos = i + 1
Expand All @@ -52,26 +54,26 @@ def cpu_soft_nms(
pos = pos + 1

# add max box as a detection
boxes[i,0] = boxes[maxpos,0]
boxes[i,1] = boxes[maxpos,1]
boxes[i,2] = boxes[maxpos,2]
boxes[i,3] = boxes[maxpos,3]
boxes[i,4] = boxes[maxpos,4]
boxes[i, 0] = boxes[maxpos, 0]
boxes[i, 1] = boxes[maxpos, 1]
boxes[i, 2] = boxes[maxpos, 2]
boxes[i, 3] = boxes[maxpos, 3]
boxes[i, 4] = boxes[maxpos, 4]
inds[i] = inds[maxpos]

# swap ith box with position of max box
boxes[maxpos,0] = tx1
boxes[maxpos,1] = ty1
boxes[maxpos,2] = tx2
boxes[maxpos,3] = ty2
boxes[maxpos,4] = ts
boxes[maxpos, 0] = tx1
boxes[maxpos, 1] = ty1
boxes[maxpos, 2] = tx2
boxes[maxpos, 3] = ty2
boxes[maxpos, 4] = ts
inds[maxpos] = ti

tx1 = boxes[i,0]
ty1 = boxes[i,1]
tx2 = boxes[i,2]
ty2 = boxes[i,3]
ts = boxes[i,4]
tx1 = boxes[i, 0]
ty1 = boxes[i, 1]
tx2 = boxes[i, 2]
ty2 = boxes[i, 3]
ts = boxes[i, 4]

pos = i + 1
# NMS iterations, note that N changes if detection boxes fall below
Expand All @@ -89,35 +91,35 @@ def cpu_soft_nms(
ih = (min(ty2, y2) - max(ty1, y1) + 1)
if ih > 0:
ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih)
ov = iw * ih / ua #iou between max box and detection box
ov = iw * ih / ua # iou between max box and detection box

if method == 1: # linear
if ov > Nt:
if method == 1: # linear
if ov > iou_thr:
weight = 1 - ov
else:
weight = 1
elif method == 2: # gaussian
weight = np.exp(-(ov * ov)/sigma)
else: # original NMS
if ov > Nt:
elif method == 2: # gaussian
weight = np.exp(-(ov * ov) / sigma)
else: # original NMS
if ov > iou_thr:
weight = 0
else:
weight = 1

boxes[pos, 4] = weight*boxes[pos, 4]
boxes[pos, 4] = weight * boxes[pos, 4]

# if box score falls below threshold, discard the box by
# swapping with last box update N
if boxes[pos, 4] < threshold:
boxes[pos,0] = boxes[N-1, 0]
boxes[pos,1] = boxes[N-1, 1]
boxes[pos,2] = boxes[N-1, 2]
boxes[pos,3] = boxes[N-1, 3]
boxes[pos,4] = boxes[N-1, 4]
inds[pos] = inds[N-1]
if boxes[pos, 4] < min_score:
boxes[pos, 0] = boxes[N-1, 0]
boxes[pos, 1] = boxes[N-1, 1]
boxes[pos, 2] = boxes[N-1, 2]
boxes[pos, 3] = boxes[N-1, 3]
boxes[pos, 4] = boxes[N-1, 4]
inds[pos] = inds[N - 1]
N = N - 1
pos = pos - 1

pos = pos + 1

return boxes[:N], inds[:N]
return boxes[:N], inds[:N]
Loading

0 comments on commit dd2907e

Please sign in to comment.