use EMA normalizer in RetinaNet
Summary: fix facebookresearch#868

Reviewed By: rbgirshick, alexander-kirillov

Differential Revision: D20062149

fbshipit-source-id: 5fcf0537730f3b3f7217fde42f8b97e506eab316
ppwwyyxx authored and facebook-github-bot committed Feb 24, 2020
1 parent 037823e commit d362da6
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions detectron2/modeling/meta_arch/retinanet.py
@@ -8,6 +8,7 @@

from detectron2.layers import ShapeSpec, batched_nms, cat
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
+from detectron2.utils.events import get_event_storage
from detectron2.utils.logger import log_first_n

from ..anchor_generator import build_anchor_generator
@@ -98,6 +99,15 @@ def __init__(self, cfg):
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
self.to(self.device)

"""
In Detectron1, loss is normalized by number of foreground samples in the batch.
When batch size is 1 per GPU, #foreground has a large variance and
using it lead to lower performance. Here we maintain an EMA of #foreground to
stabilize the normalizer.
"""
self.loss_normalizer = 100 # initialize with any reasonable #fg that's not too small
self.loss_normalizer_momentum = 0.9

def forward(self, batched_inputs):
"""
Args:
@@ -172,7 +182,12 @@ def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits, pred_anchor_deltas):

valid_idxs = gt_classes >= 0
foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
-        num_foreground = foreground_idxs.sum()
+        num_foreground = foreground_idxs.sum().item()
+        get_event_storage().put_scalar("num_foreground", num_foreground)
+        self.loss_normalizer = (
+            self.loss_normalizer_momentum * self.loss_normalizer
+            + (1 - self.loss_normalizer_momentum) * num_foreground
+        )

gt_classes_target = torch.zeros_like(pred_class_logits)
gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
@@ -184,15 +199,15 @@ def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits, pred_anchor_deltas):
alpha=self.focal_loss_alpha,
gamma=self.focal_loss_gamma,
reduction="sum",
-        ) / max(1, num_foreground)
+        ) / max(1, self.loss_normalizer)

# regression loss
loss_box_reg = smooth_l1_loss(
pred_anchor_deltas[foreground_idxs],
gt_anchors_deltas[foreground_idxs],
beta=self.smooth_l1_loss_beta,
reduction="sum",
-        ) / max(1, num_foreground)
+        ) / max(1, self.loss_normalizer)

return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
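
Below is a minimal standalone sketch of the EMA normalizer this commit introduces. It is not part of the commit; the per-iteration foreground counts are made-up values chosen to mimic the high variance seen with batch size 1 per GPU:

momentum = 0.9         # corresponds to self.loss_normalizer_momentum
loss_normalizer = 100  # same initial value as in the diff

for num_foreground in [3, 250, 12, 180, 7]:  # hypothetical noisy #fg per batch
    loss_normalizer = momentum * loss_normalizer + (1 - momentum) * num_foreground
    # Losses are divided by max(1, loss_normalizer) instead of the raw,
    # noisy num_foreground, so the effective loss scale drifts slowly.
    print(f"#fg={num_foreground:3d}  normalizer={loss_normalizer:6.2f}")

With momentum 0.9, one extreme batch moves the normalizer by only 10% of its deviation from the running average, which is the stabilization the docstring describes.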

