From 38e2cdec48a96e01ded8666567a305bb4e865970 Mon Sep 17 00:00:00 2001 From: pangjm Date: Mon, 9 Dec 2019 21:00:04 -0800 Subject: [PATCH] Fix one bug in sampling --- configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py | 2 +- configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py | 2 +- configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py | 2 +- mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py index 6797539..7a31476 100644 --- a/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py +++ b/configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py @@ -97,7 +97,7 @@ pos_sampler=dict(type='InstanceBalancedPosSampler'), neg_sampler=dict( type='IoUBalancedNegSampler', - floor_thr=0, + floor_thr=-1, floor_fraction=0, num_bins=3)), pos_weight=-1, diff --git a/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py index 473c9ef..d596903 100644 --- a/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py +++ b/configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py @@ -97,7 +97,7 @@ pos_sampler=dict(type='InstanceBalancedPosSampler'), neg_sampler=dict( type='IoUBalancedNegSampler', - floor_thr=0, + floor_thr=-1, floor_fraction=0, num_bins=3)), pos_weight=-1, diff --git a/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py b/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py index 5872c4e..26172a4 100644 --- a/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py +++ b/configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py @@ -99,7 +99,7 @@ pos_sampler=dict(type='InstanceBalancedPosSampler'), neg_sampler=dict( type='IoUBalancedNegSampler', - floor_thr=0, + floor_thr=-1, floor_fraction=0, num_bins=3)), pos_weight=-1, diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py index 1a19dc7..b20d476 100644 --- a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py +++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py @@ -27,13 +27,13 @@ class IoUBalancedNegSampler(RandomSampler): def __init__(self, num, pos_fraction, - floor_thr=0, + floor_thr=-1, floor_fraction=0, num_bins=3, **kwargs): super(IoUBalancedNegSampler, self).__init__(num, pos_fraction, **kwargs) - assert floor_thr >= 0 + assert floor_thr >= 0 or floor_thr == -1 assert 0 <= floor_fraction <= 1 assert num_bins >= 1 @@ -98,6 +98,7 @@ def _sample_neg(self, assign_result, num_expected, **kwargs): floor_set = set() iou_sampling_set = set( np.where(max_overlaps > self.floor_thr)[0]) + self.floor_thr == 0 floor_neg_inds = list(floor_set & neg_set) iou_sampling_neg_inds = list(iou_sampling_set & neg_set)