Skip to content

Commit

Permalink
update d2 version
Browse files Browse the repository at this point in the history
  • Loading branch information
xingyizhou committed Apr 26, 2022
1 parent 1799592 commit 88dcdd8
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 202 deletions.
269 changes: 86 additions & 183 deletions unidet/modeling/roi_heads/custom_fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,168 +13,10 @@
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers, FastRCNNOutputs
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats

__all__ = ["CustomFastRCNNOutputLayers", "CustomFastRCNNOutputs"]


class CustomFastRCNNOutputs(FastRCNNOutputs):
    """
    Fast R-CNN loss computation extended with:
      * sigmoid (multi-label) classification as an alternative to softmax,
      * equalization-loss (EQL) and federated-loss class reweighting driven
        by per-class frequency weights,
      * label-hierarchy handling (ignoring / positively labeling parents),
      * robust handling of batches that contain no proposals at all.

    ``freq_weight`` and ``hierarchy_weight`` are precomputed tensors passed in
    by the output layer; all switches come from ``cfg.MODEL.ROI_BOX_HEAD``.
    """

    def __init__(
        self,
        cfg,
        box2box_transform,
        pred_class_logits,
        pred_proposal_deltas,
        proposals,
        smooth_l1_beta=0.0,
        box_reg_loss_type="smooth_l1",
        freq_weight=None,
        hierarchy_weight=None
    ):
        # Base class stores logits/deltas/proposals and derives gt_classes,
        # gt_boxes, num_preds_per_image, etc.
        super().__init__(box2box_transform, pred_class_logits,
            pred_proposal_deltas, proposals, smooth_l1_beta, box_reg_loss_type)
        # Re-derive the empty-batch flag so losses can short-circuit safely.
        self._no_instances = (self.pred_class_logits.numel() == 0) or (len(proposals) == 0)
        if self._no_instances:
            print('No instances!', pred_class_logits.shape, pred_proposal_deltas.shape, len(proposals))
        # Fixed normalizer for box regression (images x samples-per-image),
        # used when FIX_NORM_REG is on instead of the actual foreground count.
        self.box_batch_size = cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE * len(proposals)
        self.fix_norm_reg = cfg.MODEL.ROI_BOX_HEAD.FIX_NORM_REG
        self.use_sigmoid_ce = cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE
        self.use_eql_loss = cfg.MODEL.ROI_BOX_HEAD.USE_EQL_LOSS
        self.use_fed_loss = cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS
        self.fed_loss_num_cat = cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT
        self.pos_parents = cfg.MODEL.ROI_BOX_HEAD.HIERARCHY_POS_PARENTS
        self.hierarchy_ignore = cfg.MODEL.ROI_BOX_HEAD.HIERARCHY_IGNORE
        if self.pos_parents and (hierarchy_weight is not None):
            # hierarchy_weight is a pair: (ignore weights, parent matrix).
            self.hierarchy_weight = hierarchy_weight[0]  # (C + 1) x C
            self.is_parents = hierarchy_weight[1]
        else:
            self.hierarchy_weight = hierarchy_weight  # (C + 1) x C
        self.freq_weight = freq_weight

    def sigmoid_cross_entropy_loss(self):
        """
        Per-class binary cross-entropy over the C foreground classes
        (the last logit column, background, is excluded), averaged over B.
        Optional EQL / federated-loss / hierarchy weights modulate each
        B x C entry.
        """
        if self._no_instances:
            # Keep the graph connected while contributing zero loss.
            return self.pred_class_logits.sum() * 0.
        self._log_accuracy()

        B = self.pred_class_logits.shape[0]
        C = self.pred_class_logits.shape[1] - 1

        # One-hot targets; background rows end up all-zero after the slice.
        target = self.pred_class_logits.new_zeros(B, C + 1)
        target[range(len(self.gt_classes)), self.gt_classes] = 1  # B x (C + 1)
        target = target[:, :C]  # B x C

        weight = 1
        if (self.freq_weight is not None) and self.use_eql_loss:  # eql loss
            # Suppress negative gradients of rare classes on foreground rows.
            exclude_weight = (self.gt_classes != C).float().view(B, 1).expand(B, C)
            threshold_weight = self.freq_weight.view(1, C).expand(B, C)
            eql_w = 1 - exclude_weight * threshold_weight * (1 - target)  # B x C
            weight = weight * eql_w

        if (self.freq_weight is not None) and self.use_fed_loss:  # fedloss
            # Only penalize classes that appear in the batch plus a sampled
            # subset, up to fed_loss_num_cat classes in total.
            appeared = torch.unique(self.gt_classes)  # C'
            prob = appeared.new_ones(C + 1).float()
            if len(appeared) < self.fed_loss_num_cat:
                # NOTE(review): self.fed_loss_freq_weight is never assigned in
                # this __init__ — confirm it is set elsewhere, otherwise this
                # branch raises AttributeError.
                if self.fed_loss_freq_weight > 0:
                    prob[:C] = self.freq_weight.float().clone()
                else:
                    prob[:C] = prob[:C] * (1 - self.freq_weight)
                prob[appeared] = 0
                more_appeared = torch.multinomial(
                    prob, self.fed_loss_num_cat - len(appeared),
                    replacement=False)
                appeared = torch.cat([appeared, more_appeared])
            appeared_mask = appeared.new_zeros(C + 1)
            appeared_mask[appeared] = 1  # C + 1
            appeared_mask = appeared_mask[:C]
            fed_w = appeared_mask.view(1, C).expand(B, C)
            weight = weight * fed_w

        if (self.hierarchy_weight is not None) and self.hierarchy_ignore:
            if self.pos_parents:
                # Also label every ancestor of the gt class as positive.
                target = torch.mm(target, self.is_parents) + target  # B x C
            hierarchy_w = self.hierarchy_weight[self.gt_classes]  # B x C
            weight = weight * hierarchy_w

        cls_loss = F.binary_cross_entropy_with_logits(
            self.pred_class_logits[:, :-1], target, reduction='none')  # B x C
        return torch.sum(cls_loss * weight) / B

    def softmax_cross_entropy_loss(self):
        """
        Standard softmax cross-entropy; differs from the base class only in
        the zero-loss handling of empty batches.
        """
        if self._no_instances:
            return self.pred_class_logits.sum() * 0.
        else:
            self._log_accuracy()
            return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")


    def box_reg_loss(self):
        """
        Smooth-L1 or GIoU regression loss over foreground proposals.
        Differs from the base class in the empty-batch handling and in the
        optional fixed normalizer (box_batch_size) instead of gt_classes.numel().
        """
        if self._no_instances:
            print('No instance in box reg loss')
            return self.pred_proposal_deltas.sum() * 0.

        box_dim = self.gt_boxes.tensor.size(1)  # 4 or 5
        cls_agnostic_bbox_reg = self.pred_proposal_deltas.size(1) == box_dim
        device = self.pred_proposal_deltas.device

        bg_class_ind = self.pred_class_logits.shape[1] - 1

        # Foreground = real gt class in [0, bg); background rows are skipped.
        fg_inds = nonzero_tuple((self.gt_classes >= 0) & (self.gt_classes < bg_class_ind))[0]
        if cls_agnostic_bbox_reg:
            gt_class_cols = torch.arange(box_dim, device=device)
        else:
            # Columns of the per-class delta predictions for each fg proposal.
            fg_gt_classes = self.gt_classes[fg_inds]
            gt_class_cols = box_dim * fg_gt_classes[:, None] + torch.arange(box_dim, device=device)

        if self.box_reg_loss_type == "smooth_l1":
            gt_proposal_deltas = self.box2box_transform.get_deltas(
                self.proposals.tensor, self.gt_boxes.tensor
            )
            loss_box_reg = smooth_l1_loss(
                self.pred_proposal_deltas[fg_inds[:, None], gt_class_cols],
                gt_proposal_deltas[fg_inds],
                self.smooth_l1_beta,
                reduction="sum",
            )
        elif self.box_reg_loss_type == "giou":
            loss_box_reg = giou_loss(
                self._predict_boxes()[fg_inds[:, None], gt_class_cols],
                self.gt_boxes.tensor[fg_inds],
                reduction="sum",
            )
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        if self.fix_norm_reg:
            # Normalize by the fixed sampling budget rather than actual count.
            loss_box_reg = loss_box_reg / self.box_batch_size
        else:
            loss_box_reg = loss_box_reg / self.gt_classes.numel()
        return loss_box_reg

    def losses(self):
        """Return {"loss_cls", "loss_box_reg"} using the configured cls loss."""
        if self.use_sigmoid_ce:
            loss_cls = self.sigmoid_cross_entropy_loss()
        else:
            loss_cls = self.softmax_cross_entropy_loss()
        return {
            "loss_cls": loss_cls,
            "loss_box_reg": self.box_reg_loss()
        }

    def predict_probs(self):
        """
        Deprecated. Per-image class probabilities: sigmoid or softmax of the
        logits, split by num_preds_per_image.
        """
        if self.use_sigmoid_ce:
            # F.sigmoid is deprecated in newer torch; torch.sigmoid is the
            # modern equivalent.
            probs = F.sigmoid(self.pred_class_logits)
        else:
            probs = F.softmax(self.pred_class_logits, dim=-1)
        return probs.split(self.num_preds_per_image, dim=0)
__all__ = ["CustomFastRCNNOutputLayers"]


def _load_class_freq(cfg):
Expand Down Expand Up @@ -239,15 +81,24 @@ def __init__(
):
super().__init__(cfg, input_shape, **kwargs)
self.use_sigmoid_ce = cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE
self.use_eql_loss = cfg.MODEL.ROI_BOX_HEAD.USE_EQL_LOSS
self.use_fed_loss = cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS
self.fed_loss_num_cat = cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT
self.pos_parents = cfg.MODEL.ROI_BOX_HEAD.HIERARCHY_POS_PARENTS
self.hierarchy_ignore = cfg.MODEL.ROI_BOX_HEAD.HIERARCHY_IGNORE

if self.use_sigmoid_ce:
prior_prob = cfg.MODEL.ROI_BOX_HEAD.PRIOR_PROB
bias_value = -math.log((1 - prior_prob) / prior_prob)
nn.init.constant_(self.cls_score.bias, bias_value)

self.cfg = cfg
self.freq_weight = _load_class_freq(cfg)
self.hierarchy_weight = _load_class_hierarchy(cfg)
hierarchy_weight = _load_class_hierarchy(cfg)
if self.pos_parents and (hierarchy_weight is not None):
self.hierarchy_weight = hierarchy_weight[0] # (C + 1) x C
self.is_parents = hierarchy_weight[1]
else:
self.hierarchy_weight = hierarchy_weight # (C + 1) x C


def predict_probs(self, predictions, proposals):
Expand All @@ -261,29 +112,81 @@ def predict_probs(self, predictions, proposals):
return probs.split(num_inst_per_image, dim=0)


def sigmoid_cross_entropy_loss(
        self, pred_class_logits, gt_classes, use_advanced_loss=True):
    """
    Binary cross-entropy classification loss over the C foreground classes.

    Args:
        pred_class_logits: B x (C + 1) logits; the last column is background
            and is excluded from the BCE.
        gt_classes: B ground-truth class ids in [0, C]; C means background.
        use_advanced_loss: enable federated-loss and hierarchy reweighting
            (both also require the corresponding weights/config to be set).

    Returns:
        Scalar loss tensor: sum of weighted per-entry BCE divided by B.
    """
    if pred_class_logits.numel() == 0:
        # More robust than `.sum() * 0.` for empty inputs.
        return pred_class_logits.new_zeros([1])[0]

    # BUG FIX: the original body read `self.pred_class_logits`, an attribute
    # this output-layers object never defines (it belonged to the removed
    # CustomFastRCNNOutputs helper); use the `pred_class_logits` argument,
    # consistent with the empty-input guard above.
    B = pred_class_logits.shape[0]
    C = pred_class_logits.shape[1] - 1

    # One-hot targets; background rows become all-zero after dropping the
    # last column.
    target = pred_class_logits.new_zeros(B, C + 1)
    target[range(len(gt_classes)), gt_classes] = 1  # B x (C + 1)
    target = target[:, :C]  # B x C

    weight = 1
    if use_advanced_loss and (self.freq_weight is not None) and \
        self.use_fed_loss:  # fedloss: penalize only a sampled class subset
        appeared = torch.unique(gt_classes)  # C'
        prob = appeared.new_ones(C + 1).float()
        if len(appeared) < self.fed_loss_num_cat:
            # NOTE(review): self.fed_loss_freq_weight is not assigned in the
            # visible __init__ — confirm it is initialized elsewhere.
            if self.fed_loss_freq_weight > 0:
                prob[:C] = self.freq_weight.float().clone()
            else:
                prob[:C] = prob[:C] * (1 - self.freq_weight)
            prob[appeared] = 0
            more_appeared = torch.multinomial(
                prob, self.fed_loss_num_cat - len(appeared),
                replacement=False)
            appeared = torch.cat([appeared, more_appeared])
        appeared_mask = appeared.new_zeros(C + 1)
        appeared_mask[appeared] = 1  # C + 1
        appeared_mask = appeared_mask[:C]
        fed_w = appeared_mask.view(1, C).expand(B, C)
        weight = weight * fed_w

    if use_advanced_loss and (self.hierarchy_weight is not None) and \
        self.hierarchy_ignore:
        if self.pos_parents:
            # Also label every ancestor of the gt class as a positive.
            target = torch.mm(target, self.is_parents) + target  # B x C
        hierarchy_w = self.hierarchy_weight[gt_classes]  # B x C
        weight = weight * hierarchy_w

    cls_loss = F.binary_cross_entropy_with_logits(
        pred_class_logits[:, :-1], target, reduction='none')  # B x C
    return torch.sum(cls_loss * weight) / B


def losses(self, predictions, proposals, use_advanced_loss=True):
    """
    Args:
        predictions: return values of :meth:`forward()` — (scores, proposal_deltas).
        proposals (list[Instances]): proposals that match the features that were used
            to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
            ``gt_classes`` are expected.
        use_advanced_loss: enable the federated-loss / hierarchy reweighting in
            the sigmoid classification loss; must be False for the softmax path.
    Returns:
        Dict[str, Tensor]: dict of losses ("loss_cls", "loss_box_reg")
    """
    # BUG FIX (diff-scrape artifact): the deleted pre-refactor body — a
    # CustomFastRCNNOutputs(...).losses() call followed by an unconditional
    # return — preceded the new implementation, making everything below it
    # unreachable dead code. Only the new implementation is kept.
    # NOTE(review): the old path also scaled losses by self.loss_weight;
    # the new path does not — confirm this is intended by the refactor.
    scores, proposal_deltas = predictions

    gt_classes = (
        cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
    )
    _log_classification_stats(scores, gt_classes)

    if len(proposals):
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
        assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
        # Fall back to the proposal box itself when no gt box was matched.
        gt_boxes = cat(
            [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
            dim=0,
        )
    else:
        proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

    if self.use_sigmoid_ce:
        loss_cls = self.sigmoid_cross_entropy_loss(
            scores, gt_classes, use_advanced_loss)
    else:
        # Advanced reweighting is only implemented for the sigmoid loss.
        assert not use_advanced_loss
        # NOTE(review): softmax_cross_entropy_loss is not visible in this
        # diff — confirm it is defined on this class or an ancestor.
        loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
    return {
        "loss_cls": loss_cls,
        "loss_box_reg": self.box_reg_loss(
            proposal_boxes, gt_boxes, proposal_deltas, gt_classes)
    }
47 changes: 28 additions & 19 deletions unidet/modeling/roi_heads/multi_dataset_fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,12 @@
import math
from typing import Dict, Union
import torch
from fvcore.nn import giou_loss, smooth_l1_loss
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Linear, ShapeSpec, batched_nms, cat, nonzero_tuple
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from .custom_fast_rcnn import CustomFastRCNNOutputLayers, CustomFastRCNNOutputs
from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats
from .custom_fast_rcnn import CustomFastRCNNOutputLayers

class MultiDatasetFastRCNNOutputLayers(CustomFastRCNNOutputLayers):
def __init__(
Expand Down Expand Up @@ -48,18 +44,31 @@ def forward(self, x, dataset_source=-1):
return scores, proposal_deltas

def losses(self, predictions, proposals, dataset_source):
    """
    Multi-dataset variant of the Fast R-CNN losses: the advanced
    (federated / hierarchy) classification reweighting is enabled only for
    the OpenImages dataset head.

    Args:
        predictions: (scores, proposal_deltas) from :meth:`forward`.
        proposals (list[Instances]): sampled proposals with ``proposal_boxes``,
            ``gt_boxes``, ``gt_classes``.
        dataset_source (int): index of the dataset this batch came from.
    Returns:
        Dict[str, Tensor]: {"loss_cls", "loss_box_reg"}
    """
    # BUG FIX (diff-scrape artifact): the deleted pre-refactor body — a
    # CustomFastRCNNOutputs(...).losses() call followed by an unconditional
    # return — preceded the new implementation, shadowing it as dead code.
    # Only the new implementation is kept.
    use_advanced_loss = (dataset_source == self.openimage_index)
    scores, proposal_deltas = predictions

    gt_classes = (
        cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)
    )
    _log_classification_stats(scores, gt_classes)

    if len(proposals):
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
        assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
        # Fall back to the proposal box itself when no gt box was matched.
        gt_boxes = cat(
            [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
            dim=0,
        )
    else:
        proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

    if self.use_sigmoid_ce:
        loss_cls = self.sigmoid_cross_entropy_loss(
            scores, gt_classes, use_advanced_loss)
    else:
        # Advanced reweighting is only implemented for the sigmoid loss.
        assert not use_advanced_loss
        loss_cls = self.softmax_cross_entropy_loss(scores, gt_classes)
    return {
        "loss_cls": loss_cls,
        "loss_box_reg": self.box_reg_loss(
            proposal_boxes, gt_boxes, proposal_deltas, gt_classes)
    }

0 comments on commit 88dcdd8

Please sign in to comment.