Skip to content

Commit

Permalink
Fix one bug in sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
OceanPang committed Dec 10, 2019
1 parent 72ea319 commit 38e2cde
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/libra_rcnn/libra_faster_rcnn_r101_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(
type='IoUBalancedNegSampler',
floor_thr=0,
floor_thr=-1,
floor_fraction=0,
num_bins=3)),
pos_weight=-1,
Expand Down
2 changes: 1 addition & 1 deletion configs/libra_rcnn/libra_faster_rcnn_r50_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(
type='IoUBalancedNegSampler',
floor_thr=0,
floor_thr=-1,
floor_fraction=0,
num_bins=3)),
pos_weight=-1,
Expand Down
2 changes: 1 addition & 1 deletion configs/libra_rcnn/libra_faster_rcnn_x101_64x4d_fpn_1x.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
pos_sampler=dict(type='InstanceBalancedPosSampler'),
neg_sampler=dict(
type='IoUBalancedNegSampler',
floor_thr=0,
floor_thr=-1,
floor_fraction=0,
num_bins=3)),
pos_weight=-1,
Expand Down
5 changes: 3 additions & 2 deletions mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class IoUBalancedNegSampler(RandomSampler):
def __init__(self,
num,
pos_fraction,
floor_thr=0,
floor_thr=-1,
floor_fraction=0,
num_bins=3,
**kwargs):
super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
**kwargs)
assert floor_thr >= 0
assert floor_thr >= 0 or floor_thr == -1
assert 0 <= floor_fraction <= 1
assert num_bins >= 1

Expand Down Expand Up @@ -98,6 +98,7 @@ def _sample_neg(self, assign_result, num_expected, **kwargs):
floor_set = set()
iou_sampling_set = set(
np.where(max_overlaps > self.floor_thr)[0])
self.floor_thr == 0

floor_neg_inds = list(floor_set & neg_set)
iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
Expand Down

0 comments on commit 38e2cde

Please sign in to comment.