Skip to content

Commit

Permalink
Add focal gamma param
Browse files Browse the repository at this point in the history
  • Loading branch information
timmeinhardt committed Apr 28, 2022
1 parent 6feed95 commit ff7f233
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 14 deletions.
1 change: 1 addition & 0 deletions cfgs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ giou_loss_coef: 2
eos_coef: 0.1
focal_loss: false
focal_alpha: 0.25
focal_gamma: 2
# Dataset
dataset: coco
train_split: train
Expand Down
1 change: 1 addition & 0 deletions cfgs/train_focal_loss.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
focal_loss: true
focal_alpha: 0.25
focal_gamma: 2
cls_loss_coef: 2.0
set_cost_class: 2.0
2 changes: 2 additions & 0 deletions src/trackformer/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def evaluate(model, criterion, postprocessors, data_loader, device,
results_orig = {
target['image_id'].item(): output
for target, output in zip(targets, results_orig)}

coco_evaluator.update(results_orig)

if panoptic_evaluator is not None:
Expand Down Expand Up @@ -333,6 +334,7 @@ def evaluate(model, criterion, postprocessors, data_loader, device,
if visualizers:
vis_epoch = visualizers['epoch_metrics']
y_data = [stats[legend_name] for legend_name in vis_epoch.viz_opts['legend']]

vis_epoch.plot(y_data, epoch)

visualizers['epoch_eval'].plot(eval_stats, epoch)
Expand Down
1 change: 1 addition & 0 deletions src/trackformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def build_model(args):
losses=losses,
focal_loss=args.focal_loss,
focal_alpha=args.focal_alpha,
focal_gamma=args.focal_gamma,
tracking=args.tracking,
track_query_false_positive_eos_weight=args.track_query_false_positive_eos_weight,)
criterion.to(device)
Expand Down
4 changes: 1 addition & 3 deletions src/trackformer/models/deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from torch import nn

from ..util import box_ops
from ..util.misc import (NestedTensor, accuracy, get_world_size,
inverse_sigmoid, is_dist_avail_and_initialized,
nested_tensor_from_tensor_list, sigmoid_focal_loss)
from ..util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list
from .detr import DETR, PostProcess, SetCriterion


Expand Down
45 changes: 43 additions & 2 deletions src/trackformer/models/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class SetCriterion(nn.Module):
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
focal_loss, focal_alpha, tracking, track_query_false_positive_eos_weight):
focal_loss, focal_alpha, focal_gamma, tracking, track_query_false_positive_eos_weight):
""" Create the criterion.
Parameters:
num_classes: number of object categories, omitting the special no-object category
Expand All @@ -165,6 +165,7 @@ def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
self.register_buffer('empty_weight', empty_weight)
self.focal_loss = focal_loss
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
self.tracking = tracking
self.track_query_false_positive_eos_weight = track_query_false_positive_eos_weight

Expand Down Expand Up @@ -233,7 +234,9 @@ def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
query_mask = torch.stack([~t['track_queries_placeholder_mask'] for t in targets])[..., None]
query_mask = query_mask.repeat(1, 1, self.num_classes)

loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2, query_mask=query_mask)
loss_ce = sigmoid_focal_loss(
src_logits, target_classes_onehot, num_boxes,
alpha=self.focal_alpha, gamma=self.focal_gamma, query_mask=query_mask)

if self.tracking:
mean_num_queries = torch.tensor([len(m.nonzero()) for m in query_mask]).float().mean()
Expand All @@ -245,6 +248,26 @@ def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
if log:
# TODO this should probably be a separate loss, not hacked in this one here
losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]

# compute separate track and object query losses
# loss_ce = sigmoid_focal_loss(
# src_logits, target_classes_onehot, num_boxes,
# alpha=self.focal_alpha, gamma=self.focal_gamma, query_mask=query_mask, reduction=False)
# loss_ce *= src_logits.shape[1]

# track_query_target_masks = []
# for t, ind in zip(targets, indices):
# track_query_target_mask = torch.zeros_like(ind[1]).bool()

# for i in t['track_query_match_ids']:
# track_query_target_mask[ind[1].eq(i).nonzero()[0]] = True

# track_query_target_masks.append(track_query_target_mask)
# track_query_target_masks = torch.cat(track_query_target_masks)

# losses['loss_ce_track_queries'] = loss_ce[idx][track_query_target_masks].mean(1).sum() / num_boxes
# losses['loss_ce_object_queries'] = loss_ce[idx][~track_query_target_masks].mean(1).sum() / num_boxes

return losses

@torch.no_grad()
Expand Down Expand Up @@ -282,6 +305,24 @@ def loss_boxes(self, outputs, targets, indices, num_boxes):
box_ops.box_cxcywh_to_xyxy(src_boxes),
box_ops.box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes

# compute separate track and object query losses
# track_query_target_masks = []
# for t, ind in zip(targets, indices):
# track_query_target_mask = torch.zeros_like(ind[1]).bool()

# for i in t['track_query_match_ids']:
# track_query_target_mask[ind[1].eq(i).nonzero()[0]] = True

# track_query_target_masks.append(track_query_target_mask)
# track_query_target_masks = torch.cat(track_query_target_masks)

# losses['loss_bbox_track_queries'] = loss_bbox[track_query_target_masks].sum() / num_boxes
# losses['loss_bbox_object_queries'] = loss_bbox[~track_query_target_masks].sum() / num_boxes

# losses['loss_giou_track_queries'] = loss_giou[track_query_target_masks].sum() / num_boxes
# losses['loss_giou_object_queries'] = loss_giou[~track_query_target_masks].sum() / num_boxes

return losses

def loss_masks(self, outputs, targets, indices, num_boxes):
Expand Down
5 changes: 3 additions & 2 deletions src/trackformer/models/detr_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@ def forward(self, samples: NestedTensor, targets: list = None, prev_features=Non
if targets is not None and not self._tracking:
prev_targets = [target['prev_target'] for target in targets]

if self.training: # and random.uniform(0, 1) < 0.5:
# if self.training:
# if self.training and random.uniform(0, 1) < 0.5:
if self.training:
# if True:
backprop_context = torch.no_grad
if self._backprop_prev_frame:
backprop_context = nullcontext
Expand Down
11 changes: 6 additions & 5 deletions src/trackformer/models/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class HungarianMatcher(nn.Module):
"""

def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1,
focal_loss: bool = False, focal_alpha: float = 0.25):
focal_loss: bool = False, focal_alpha: float = 0.25, focal_gamma: float = 2.0):
"""Creates the matcher
Params:
Expand All @@ -35,6 +35,7 @@ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float
self.cost_giou = cost_giou
self.focal_loss = focal_loss
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

@torch.no_grad()
Expand Down Expand Up @@ -80,9 +81,8 @@ def forward(self, outputs, targets):

# Compute the classification cost.
if self.focal_loss:
gamma = 2.0
neg_cost_class = (1 - self.focal_alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = self.focal_alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
neg_cost_class = (1 - self.focal_alpha) * (out_prob ** self.focal_gamma) * (-(1 - out_prob + 1e-8).log())
pos_cost_class = self.focal_alpha * ((1 - out_prob) ** self.focal_gamma) * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
else:
# Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class].
Expand Down Expand Up @@ -136,4 +136,5 @@ def build_matcher(args):
cost_bbox=args.set_cost_bbox,
cost_giou=args.set_cost_giou,
focal_loss=args.focal_loss,
focal_alpha=args.focal_alpha,)
focal_alpha=args.focal_alpha,
focal_gamma=args.focal_gamma,)
5 changes: 4 additions & 1 deletion src/trackformer/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def dice_loss(inputs, targets, num_boxes):
return loss.sum() / num_boxes


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, query_mask=None):
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, query_mask=None, reduction=True):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
Expand All @@ -561,6 +561,9 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

if not reduction:
return loss

if query_mask is not None:
loss = torch.stack([l[m].mean(0) for l, m in zip(loss, query_mask)])
return loss.sum() / num_boxes
Expand Down
2 changes: 1 addition & 1 deletion src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def match_name_keywords(n, name_keywords):
if args.eval_only:
_, coco_evaluator = evaluate(
model, criterion, postprocessors, data_loader_val, device,
output_dir, visualizers['val'], args)
output_dir, visualizers['val'], args, 0)
if args.output_dir:
utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")

Expand Down

0 comments on commit ff7f233

Please sign in to comment.