FastRCNNOutputs support training with no images
Reviewed By: rbgirshick

Differential Revision: D20348885

fbshipit-source-id: 06b38916a8f56ff8849754ce27b942a19759f7b5
ppwwyyxx authored and facebook-github-bot committed Mar 9, 2020
1 parent 9d969b7 commit 5cf3dc7
Showing 3 changed files with 51 additions and 19 deletions.
50 changes: 32 additions & 18 deletions detectron2/modeling/roi_heads/fast_rcnn.py
@@ -70,7 +70,7 @@ def fast_rcnn_inference(boxes, scores, image_shapes, score_thresh, nms_thresh, t
)
for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
]
return tuple(list(x) for x in zip(*result_per_image))
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]


def fast_rcnn_inference_single_image(
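
A note on the return-statement change in the hunk above: when result_per_image is empty (a batch with no images), the old tuple(list(x) for x in zip(*result_per_image)) collapses to an empty tuple, which cannot be unpacked into the two expected values; the new form always returns two lists, each possibly empty. A minimal standalone sketch of the difference (plain Python, not part of the diff; the variable names old_style/new_style are only illustrative):

# Each entry would be an (instances, kept_indices) pair; here the batch is empty.
result_per_image = []

old_style = tuple(list(x) for x in zip(*result_per_image))
print(old_style)  # () -- "instances, indices = old_style" would raise ValueError

new_style = [x[0] for x in result_per_image], [x[1] for x in result_per_image]
print(new_style)  # ([], []) -- still unpacks into two (empty) lists
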
@@ -149,6 +149,7 @@ def __init__(
proposals for image i, in the field "proposal_boxes".
When training, each Instances must have ground-truth labels
stored in the field "gt_classes" and "gt_boxes".
The total number of all instances must be equal to R.
smooth_l1_beta (float): The transition point between L1 and L2 loss in
the smooth L1 loss function. When set to 0, the loss becomes L1. When
set to +inf, the loss becomes constant 0.
@@ -159,17 +160,24 @@
self.pred_proposal_deltas = pred_proposal_deltas
self.smooth_l1_beta = smooth_l1_beta

box_type = type(proposals[0].proposal_boxes)
# cat(..., dim=0) concatenates over all images in the batch
self.proposals = box_type.cat([p.proposal_boxes for p in proposals])
assert not self.proposals.tensor.requires_grad, "Proposals should not require gradients!"
self.image_shapes = [x.image_size for x in proposals]

# The following fields should exist only when training.
if proposals[0].has("gt_boxes"):
self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
assert proposals[0].has("gt_classes")
self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
if len(proposals):
box_type = type(proposals[0].proposal_boxes)
# cat(..., dim=0) concatenates over all images in the batch
self.proposals = box_type.cat([p.proposal_boxes for p in proposals])
assert (
not self.proposals.tensor.requires_grad
), "Proposals should not require gradients!"
self.image_shapes = [x.image_size for x in proposals]

# The following fields should exist only when training.
if proposals[0].has("gt_boxes"):
self.gt_boxes = box_type.cat([p.gt_boxes for p in proposals])
assert proposals[0].has("gt_classes")
self.gt_classes = cat([p.gt_classes for p in proposals], dim=0)
else:
self.proposals = Boxes(torch.zeros(0, 4, device=self.pred_proposal_deltas.device))
self.image_shapes = []
self._no_instances = len(proposals) == 0 # no instances found

def _log_accuracy(self):
"""
@@ -189,10 +197,11 @@ def _log_accuracy(self):
fg_num_accurate = (fg_pred_classes == fg_gt_classes).nonzero().numel()

storage = get_event_storage()
storage.put_scalar("fast_rcnn/cls_accuracy", num_accurate / num_instances)
if num_fg > 0:
storage.put_scalar("fast_rcnn/fg_cls_accuracy", fg_num_accurate / num_fg)
storage.put_scalar("fast_rcnn/false_negative", num_false_negative / num_fg)
if num_instances > 0:
storage.put_scalar("fast_rcnn/cls_accuracy", num_accurate / num_instances)
if num_fg > 0:
storage.put_scalar("fast_rcnn/fg_cls_accuracy", fg_num_accurate / num_fg)
storage.put_scalar("fast_rcnn/false_negative", num_false_negative / num_fg)

def softmax_cross_entropy_loss(self):
"""
@@ -201,8 +210,11 @@ def softmax_cross_entropy_loss(self):
Returns:
scalar Tensor
"""
self._log_accuracy()
return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")
if self._no_instances:
return 0.0 * self.pred_class_logits.sum()
else:
self._log_accuracy()
return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")

def smooth_l1_loss(self):
"""
@@ -211,6 +223,8 @@ def smooth_l1_loss(self):
Returns:
scalar Tensor
"""
if self._no_instances:
return 0.0 * self.pred_proposal_deltas.sum()
gt_proposal_deltas = self.box2box_transform.get_deltas(
self.proposals.tensor, self.gt_boxes.tensor
)
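
The loss changes above rely on a common trick for empty batches: multiplying the sum of the (empty) prediction tensor by 0.0 yields a loss of exactly zero that is still attached to the autograd graph, so backward() produces well-defined zero gradients for the box head instead of failing on empty ground-truth tensors. A minimal standalone sketch of that gradient-flow behavior (plain PyTorch, not detectron2 code; the shapes are only illustrative):

import torch

# Zero proposals in the batch, 100 classes.
pred_class_logits = torch.randn(0, 100, requires_grad=True)

loss = 0.0 * pred_class_logits.sum()  # scalar zero, but carries a grad_fn
loss.backward()

print(loss.item())                   # 0.0
print(pred_class_logits.grad.shape)  # torch.Size([0, 100]) -- gradients exist and are zero
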
2 changes: 1 addition & 1 deletion detectron2/modeling/roi_heads/rotated_fast_rcnn.py
@@ -77,7 +77,7 @@ def fast_rcnn_inference_rotated(
)
for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)
]
return tuple(list(x) for x in zip(*result_per_image))
return [x[0] for x in result_per_image], [x[1] for x in result_per_image]


def fast_rcnn_inference_single_image_rotated(
18 changes: 18 additions & 0 deletions tests/test_fast_rcnn.py
@@ -53,6 +53,24 @@ def test_fast_rcnn(self):
for name in expected_losses.keys():
assert torch.allclose(losses[name], expected_losses[name])

def test_fast_rcnn_empty_batch(self):
cfg = get_cfg()
cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10, 10, 5, 5)
box2box_transform = Box2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS)

logits = torch.randn(0, 100, requires_grad=True)
deltas = torch.randn(0, 4, requires_grad=True)
outputs = FastRCNNOutputs(box2box_transform, logits, deltas, [], 0.5)
losses = outputs.losses()
for value in losses.values():
self.assertTrue(torch.allclose(value, torch.zeros_like(value)))
sum(losses.values()).backward()
self.assertTrue(logits.grad is not None)
self.assertTrue(deltas.grad is not None)

predictions, _ = outputs.inference(0.05, 0.5, 100)
self.assertEqual(len(predictions), 0)

def test_fast_rcnn_rotated(self):
torch.manual_seed(132)
cfg = get_cfg()
