Skip to content

Commit

Permalink
[Fix] Evaluation bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
SakiRinn committed Mar 17, 2023
1 parent b8a52e2 commit 17b1b9b
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 58 deletions.
48 changes: 25 additions & 23 deletions cntcocotools/cocoeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ def _toMask(anns, coco):
ann['segmentation'] = rle
p = self.params

kwargs = dict(imgIds=p.imgIds)
if p.useCats:
kwargs.update(catIds=p.catIds)
gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(**kwargs))
dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(**kwargs))
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
else:
gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))

if p.iouType == 'segm':
_toMask(gts, self.cocoGt)
Expand Down Expand Up @@ -83,24 +84,23 @@ def evaluate(self):
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {(imgId, catId): computeIoU(imgId, catId)
self.ious = {(imgId, catId): computeIoU(imgId, catId) \
for imgId in p.imgIds
for catId in catIds}

evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet) \
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds]
self._paramsEval = copy.deepcopy(self.params)
toc = time.time()
print('DONE (t={:0.2f}s).'.format(toc-tic))
print('DONE (t={:0.2f}s).'.format(toc - tic))

def computeIoU(self, imgId, catId):
p = self.params
if p.useCats and p.useCnts:
if p.useCats:
gt = self._gts[imgId, catId]
dt = self._dts[imgId, catId]
else:
Expand Down Expand Up @@ -160,8 +160,8 @@ def computeOks(self, imgId, catId):
dy = yd - yg
else:
z = np.zeros((k))
dx = np.max((z, x0 - xd), axis=0)+np.max((z, xd - x1), axis=0)
dy = np.max((z, y0 - yd), axis=0)+np.max((z, yd - y1), axis=0)
dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
e = (dx**2 + dy**2) / vars / (gt['area'] + np.spacing(1)) / 2
if k1 > 0:
e=e[vg > 0]
Expand All @@ -170,7 +170,7 @@ def computeOks(self, imgId, catId):

def evaluateImg(self, imgId, catId, aRng, maxDet):
p = self.params
if p.useCats and p.useCnts:
if p.useCats:
gt = self._gts[imgId, catId]
dt = self._dts[imgId, catId]
else:
Expand Down Expand Up @@ -207,7 +207,6 @@ def evaluateImg(self, imgId, catId, aRng, maxDet):
for tind, t in enumerate(p.iouThrs):
for ctind, ct in enumerate(p.acThrs):
for dind, d in enumerate(dt):

iou = min([t , 1 - 1e-10])
ac = min([ct, 1 - 1e-10])
m = -1
Expand Down Expand Up @@ -354,9 +353,9 @@ def _summarize(ap=1, iouThr=None, acThr=None, areaRng='all', maxDets=100):
p = self.params
iStr = ' {:<18} {} @[ IoU={:<9} | AC={:<9} | area={:>6s} | maxDets={:>4d} ] = {:0.3f}'
titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
typeStr = '(AP)' if ap==1 else '(AR)'
typeStr = '(AP)' if ap == 1 else '(AR)'
iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
if iouThr is None else '{:0.2f}'.format(acThr)
if iouThr is None else '{:0.2f}'.format(iouThr)
acStr = '{:0.2f}:{:0.2f}'.format(p.acThrs[0], p.acThrs[-1]) \
if acThr is None else '{:0.2f}'.format(acThr)

Expand All @@ -366,7 +365,7 @@ def _summarize(ap=1, iouThr=None, acThr=None, areaRng='all', maxDets=100):
s = self.eval['precision']
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t, ...]
s = s[t]
if acThr is not None:
ct = np.where(acThr == p.acThrs)[0]
s = s[:, ct, ...]
Expand All @@ -376,17 +375,20 @@ def _summarize(ap=1, iouThr=None, acThr=None, areaRng='all', maxDets=100):
if iouThr is not None:
t = np.where(iouThr == p.iouThrs)[0]
s = s[t]
if acThr is not None:
ct = np.where(acThr == p.acThrs)[0]
s = s[:, ct, ...]
s = s[:, :, :, aind, mind]
if len(s[s > -1])==0:
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
print(iStr.format(titleStr, typeStr, iouStr, acStr, areaRng, maxDets, mean_s))
return mean_s

def _summarizeDets():
stats = np.zeros((12,))
stats[0] = _summarize(1)
stats = np.zeros((12, ))
stats[0] = _summarize(1, maxDets=self.params.maxDets[0])
stats[1] = _summarize(1, iouThr=.5, acThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, acThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
Expand All @@ -401,7 +403,7 @@ def _summarizeDets():
return stats

def _summarizeKps():
stats = np.zeros((10,))
stats = np.zeros((10, ))
stats[0] = _summarize(1, maxDets=20)
stats[1] = _summarize(1, maxDets=20, iouThr=.5)
stats[2] = _summarize(1, maxDets=20, iouThr=.75)
Expand Down
15 changes: 2 additions & 13 deletions configs/_base_/models/cascade_rcnn_r50_fpn_locount.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@
bbox_head=[
dict(
type='FCBBoxHeadWithCount',
num_shared_convs=1,
num_shared_fcs=2,
num_cls_fcs=1,
num_reg_fcs=1,
num_cnt_fcs=1,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
Expand All @@ -78,11 +74,7 @@
loss_weight=1.0)),
dict(
type='FCBBoxHeadWithCount',
num_shared_convs=1,
num_shared_fcs=2,
num_cls_fcs=1,
num_reg_fcs=1,
num_cnt_fcs=1,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
Expand All @@ -107,11 +99,7 @@
loss_weight=1.0)),
dict(
type='FCBBoxHeadWithCount',
num_shared_convs=1,
num_shared_fcs=2,
num_cls_fcs=1,
num_reg_fcs=1,
num_cnt_fcs=1,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
Expand Down Expand Up @@ -209,7 +197,8 @@
pos_weight=-1,
debug=False)
],
stage_loss_weights=[1.0, 1.0, 1.0]), #Fixme: Very important parameters, 2020/04/03 [1, 0.5, 0.25]==>[1.0, 1.0, 1.0]
stage_loss_weights=[1.0, 0.5, 0.25],
stage_cnt_loss_weights=[0.1, 0.1, 0.1]),
test_cfg=dict(
rpn=dict(
nms_pre=1000,
Expand Down
32 changes: 14 additions & 18 deletions mmdet/models/roi_heads/bbox_heads/bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def __init__(self,
if self.with_cnt:
self.init_cfg += [
dict(
type='Normal', std=0.001, override=dict(name='fc_cnt'))
type='Normal', std=0.01, override=dict(name='fc_cnt'))
]

@property
Expand Down Expand Up @@ -823,10 +823,6 @@ def loss(self,
reduction_override=None):
losses = dict()

learning_bbox_weights = 1. * pow(2, -self.current_stage)
learning_cls_weights = 1. * pow(2, -self.current_stage)
learning_cnt_weights = 1. * pow(2, -self.num_stages)

# bbox
if bbox_pred is not None:
bg_class_ind = self.num_classes
Expand All @@ -842,7 +838,7 @@ def loss(self,
bbox_pred.size(0), -1,
4)[pos_inds.type(torch.bool),
labels[pos_inds.type(torch.bool)]]
losses['loss_bbox'] = learning_bbox_weights * self.loss_bbox(
losses['loss_bbox'] = self.loss_bbox(
pos_bbox_pred,
bbox_targets[pos_inds.type(torch.bool)],
bbox_weights[pos_inds.type(torch.bool)],
Expand All @@ -854,7 +850,7 @@ def loss(self,
if cls_score is not None:
avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
if cls_score.numel() > 0:
loss_cls_ = learning_cls_weights * self.loss_cls(
loss_cls_ = self.loss_cls(
cls_score,
labels,
label_weights,
Expand All @@ -873,7 +869,7 @@ def loss(self,
if cnt_score is not None:
avg_cnt_factor = max(torch.sum(count_weights > 0).float().item(), 1.)
if cnt_score.numel() > 0:
loss_cnt_ = learning_cnt_weights * self.loss_cnt(
loss_cnt_ = self.loss_cnt(
cnt_score,
counts,
count_weights,
Expand Down Expand Up @@ -942,16 +938,16 @@ def get_bboxes(self,
det_bboxes = torch.cat([det_bboxes, cnt_scores.unsqueeze(-1)], -1)
return det_bboxes, det_labels, det_counts

def init_weights(self):
if self.with_cls:
nn.init.normal_(self.fc_cls.weight, 0, 0.01)
nn.init.constant_(self.fc_cls.bias, 0)
if self.with_reg:
nn.init.normal_(self.fc_reg.weight, 0, 0.001)
nn.init.constant_(self.fc_reg.bias, 0)
if self.with_cnt:
nn.init.normal_(self.fc_cnt.weight, 0, 0.001)
nn.init.constant_(self.fc_cnt.bias, 0)
# def init_weights(self):
# if self.with_cls:
# nn.init.normal_(self.fc_cls.weight, 0, 0.01)
# nn.init.constant_(self.fc_cls.bias, 0)
# if self.with_reg:
# nn.init.normal_(self.fc_reg.weight, 0, 0.001)
# nn.init.constant_(self.fc_reg.bias, 0)
# if self.with_cnt:
# nn.init.normal_(self.fc_cnt.weight, 0, 0.001)
# nn.init.constant_(self.fc_cnt.bias, 0)

def div_counts(self, counts):
if not isinstance(counts, torch.Tensor):
Expand Down
15 changes: 11 additions & 4 deletions mmdet/models/roi_heads/cascade_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ class CascadeRoIHeadWithCount(BBoxTestMixinWithCount, CascadeRoIHead):
def __init__(self,
num_stages,
stage_loss_weights,
stage_cnt_loss_weights=[0.1, 0.1, 0.1],
bbox_roi_extractor=None,
bbox_head=None,
mask_roi_extractor=None,
Expand All @@ -653,6 +654,7 @@ def __init__(self,

self.num_stages = num_stages
self.stage_loss_weights = stage_loss_weights
self.stage_cnt_loss_weights = stage_cnt_loss_weights
super(CascadeRoIHead, self).__init__(
bbox_roi_extractor=bbox_roi_extractor,
bbox_head=bbox_head,
Expand Down Expand Up @@ -743,6 +745,7 @@ def forward_train(self,
self.current_stage = i
rcnn_train_cfg = self.train_cfg[i]
lw = self.stage_loss_weights[i]
lw_cnt = self.stage_loss_weights[i]

sampling_results = []
if self.with_bbox or self.with_mask:
Expand Down Expand Up @@ -770,17 +773,21 @@ def forward_train(self,
rcnn_train_cfg)

for name, value in bbox_results['loss_bbox'].items():
losses[f's{i}.{name}'] = (
value * lw if 'loss' in name else value)
if name.endswith('_cnt'):
losses[f's{i}.{name}'] = (value * lw_cnt if 'loss' in name else value)
else:
losses[f's{i}.{name}'] = (value * lw if 'loss' in name else value)

# mask head forward and loss
if self.with_mask:
mask_results = self._mask_forward_train(
i, x, sampling_results, gt_masks, rcnn_train_cfg,
bbox_results['bbox_feats'])
for name, value in mask_results['loss_mask'].items():
losses[f's{i}.{name}'] = (
value * lw if 'loss' in name else value)
if name.endswith('_cnt'):
losses[f's{i}.{name}'] = (value * lw_cnt if 'loss' in name else value)
else:
losses[f's{i}.{name}'] = (value * lw if 'loss' in name else value)

# refine bboxes
if i < self.num_stages - 1:
Expand Down

0 comments on commit 17b1b9b

Please sign in to comment.