fix(models): no-candidate anchor issue for tiny objects during label assign (Megvii-BaseDetection#1589)

Joker316701882 authored Jan 9, 2023
1 parent a4152a5 commit 14c62a7
Showing 1 changed file with 43 additions and 109 deletions.
152 changes: 43 additions & 109 deletions yolox/models/yolo_head.py
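
Editor's note before the diff: a toy illustration (not part of the commit; the stride, box size, and coordinates are made up) of why a pure box-containment test can find zero anchors for a tiny object, while the fixed-radius center rule this commit switches to cannot come up empty for any center inside the image:

import torch

# One FPN level with stride 8 on a 32x32 image: anchor centers sit at x = 4, 12, 20, 28.
stride = 8
centers = torch.arange(4, 32, stride).float()

# A 3-pixel-wide ground truth centered at x = 8.5 spans (7.0, 10.0) and contains
# no anchor center, so a "center inside the box" test yields zero candidates.
cx, w = 8.5, 3.0
in_box = (centers > cx - w / 2) & (centers < cx + w / 2)
print(in_box.any())  # tensor(False)

# The fixed-radius constraint instead accepts any anchor whose center lies
# within 1.5 * stride of the ground-truth center, which always finds candidates.
center_radius = 1.5
in_center = (centers - cx).abs() < center_radius * stride
print(in_center.any())  # tensor(True)
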
@@ -32,7 +32,6 @@ def __init__(
"""
super().__init__()

- self.n_anchors = 1
self.num_classes = num_classes
self.decode_in_inference = True # for deploy, set to False

@@ -97,7 +96,7 @@ def __init__(
self.cls_preds.append(
nn.Conv2d(
in_channels=int(256 * width),
- out_channels=self.n_anchors * self.num_classes,
+ out_channels=self.num_classes,
kernel_size=1,
stride=1,
padding=0,
@@ -115,7 +114,7 @@
self.obj_preds.append(
nn.Conv2d(
in_channels=int(256 * width),
- out_channels=self.n_anchors * 1,
+ out_channels=1,
kernel_size=1,
stride=1,
padding=0,
@@ -131,12 +130,12 @@

def initialize_biases(self, prior_prob):
for conv in self.cls_preds:
- b = conv.bias.view(self.n_anchors, -1)
+ b = conv.bias.view(1, -1)
b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

for conv in self.obj_preds:
- b = conv.bias.view(self.n_anchors, -1)
+ b = conv.bias.view(1, -1)
b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

@@ -177,7 +176,7 @@ def forward(self, xin, labels=None, imgs=None):
batch_size = reg_output.shape[0]
hsize, wsize = reg_output.shape[-2:]
reg_output = reg_output.view(
- batch_size, self.n_anchors, 4, hsize, wsize
+ batch_size, 1, 4, hsize, wsize
)
reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
batch_size, -1, 4
@@ -224,9 +223,9 @@ def get_output_and_grid(self, output, k, stride, dtype):
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
self.grids[k] = grid

- output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+ output = output.view(batch_size, 1, n_ch, hsize, wsize)
output = output.permute(0, 1, 3, 4, 2).reshape(
- batch_size, self.n_anchors * hsize * wsize, -1
+ batch_size, hsize * wsize, -1
)
grid = grid.view(1, -1, 2)
output[..., :2] = (output[..., :2] + grid) * stride
@@ -265,7 +264,7 @@ def get_losses(
dtype,
):
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
- obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
+ obj_preds = outputs[:, :, 4:5] # [batch, n_anchors_all, 1]
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]

# calculate targets
@@ -311,18 +310,14 @@
) = self.get_assignments( # noqa
batch_idx,
num_gt,
- total_num_anchors,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
expanded_strides,
x_shifts,
y_shifts,
cls_preds,
- bbox_preds,
obj_preds,
- labels,
- imgs,
)
except RuntimeError as e:
# TODO: the string might change, consider a better way
@@ -344,18 +339,14 @@
) = self.get_assignments( # noqa
batch_idx,
num_gt,
- total_num_anchors,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
expanded_strides,
x_shifts,
y_shifts,
cls_preds,
- bbox_preds,
obj_preds,
- labels,
- imgs,
"cpu",
)

@@ -433,37 +424,31 @@ def get_assignments(
self,
batch_idx,
num_gt,
- total_num_anchors,
gt_bboxes_per_image,
gt_classes,
bboxes_preds_per_image,
expanded_strides,
x_shifts,
y_shifts,
cls_preds,
- bbox_preds,
obj_preds,
- labels,
- imgs,
mode="gpu",
):

if mode == "cpu":
print("------------CPU Mode for This Batch-------------")
print("-----------Using CPU for the Current Batch-------------")
gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
gt_classes = gt_classes.cpu().float()
expanded_strides = expanded_strides.cpu().float()
x_shifts = x_shifts.cpu()
y_shifts = y_shifts.cpu()

- fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+ fg_mask, geometry_relation = self.get_geometry_constraint(
gt_bboxes_per_image,
expanded_strides,
x_shifts,
y_shifts,
- total_num_anchors,
- num_gt,
)

bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
@@ -480,8 +465,6 @@ def get_assignments(
gt_cls_per_image = (
F.one_hot(gt_classes.to(torch.int64), self.num_classes)
.float()
- .unsqueeze(1)
- .repeat(1, num_in_boxes_anchor, 1)
)
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

@@ -490,26 +473,27 @@

with torch.cuda.amp.autocast(enabled=False):
cls_preds_ = (
- cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
- * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
- )
+ cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+ ).sqrt()
pair_wise_cls_loss = F.binary_cross_entropy(
- cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
+ cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+ gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+ reduction="none"
).sum(-1)
del cls_preds_

cost = (
pair_wise_cls_loss
+ 3.0 * pair_wise_ious_loss
- + 100000.0 * (~is_in_boxes_and_center)
+ + float(1e6) * (~geometry_relation)
)

(
num_fg,
gt_matched_classes,
pred_ious_this_matching,
matched_gt_inds,
- ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+ ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

if mode == "cpu":
@@ -526,101 +510,49 @@
num_fg,
)

- def get_in_boxes_info(
+ def get_geometry_constraint(
self,
gt_bboxes_per_image,
expanded_strides,
x_shifts,
y_shifts,
- total_num_anchors,
- num_gt,
):
"""
Calculate whether the center of an object is located in a fixed range of
an anchor. This is used to avert inappropriate matching. It can also reduce
the number of candidate anchors so that the GPU memory is saved.
"""
expanded_strides_per_image = expanded_strides[0]
- x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
- y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
- x_centers_per_image = (
- (x_shifts_per_image + 0.5 * expanded_strides_per_image)
- .unsqueeze(0)
- .repeat(num_gt, 1)
- ) # [n_anchor] -> [n_gt, n_anchor]
- y_centers_per_image = (
- (y_shifts_per_image + 0.5 * expanded_strides_per_image)
- .unsqueeze(0)
- .repeat(num_gt, 1)
- )

- gt_bboxes_per_image_l = (
- (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
- .unsqueeze(1)
- .repeat(1, total_num_anchors)
- )
- gt_bboxes_per_image_r = (
- (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
- .unsqueeze(1)
- .repeat(1, total_num_anchors)
- )
- gt_bboxes_per_image_t = (
- (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
- .unsqueeze(1)
- .repeat(1, total_num_anchors)
- )
- gt_bboxes_per_image_b = (
- (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
- .unsqueeze(1)
- .repeat(1, total_num_anchors)
- )
+ x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+ y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)

- b_l = x_centers_per_image - gt_bboxes_per_image_l
- b_r = gt_bboxes_per_image_r - x_centers_per_image
- b_t = y_centers_per_image - gt_bboxes_per_image_t
- b_b = gt_bboxes_per_image_b - y_centers_per_image
- bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)

- is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
- is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
- # in fixed center

- center_radius = 2.5

- gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
- 1, total_num_anchors
- ) - center_radius * expanded_strides_per_image.unsqueeze(0)
- gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
- 1, total_num_anchors
- ) + center_radius * expanded_strides_per_image.unsqueeze(0)
- gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
- 1, total_num_anchors
- ) - center_radius * expanded_strides_per_image.unsqueeze(0)
- gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
- 1, total_num_anchors
- ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+ center_radius = 1.5
+ center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+ gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+ gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+ gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+ gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist

c_l = x_centers_per_image - gt_bboxes_per_image_l
c_r = gt_bboxes_per_image_r - x_centers_per_image
c_t = y_centers_per_image - gt_bboxes_per_image_t
c_b = gt_bboxes_per_image_b - y_centers_per_image
center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
is_in_centers = center_deltas.min(dim=-1).values > 0.0
- is_in_centers_all = is_in_centers.sum(dim=0) > 0
+ anchor_filter = is_in_centers.sum(dim=0) > 0
+ geometry_relation = is_in_centers[:, anchor_filter]

- # in boxes and in centers
- is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all

- is_in_boxes_and_center = (
- is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
- )
- return is_in_boxes_anchor, is_in_boxes_and_center
+ return anchor_filter, geometry_relation

- def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
- # Dynamic K
- # ---------------------------------------------------------------
+ def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

- ious_in_boxes_matrix = pair_wise_ious
- n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
- topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+ n_candidate_k = min(10, pair_wise_ious.size(1))
+ topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
dynamic_ks = dynamic_ks.tolist()
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
)
matching_matrix[gt_idx][pos_idx] = 1

del topk_ious, dynamic_ks, pos_idx

anchor_matching_gt = matching_matrix.sum(0)
+ # deal with the case that one anchor matches multiple ground-truths
if anchor_matching_gt.max() > 1:
- _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
- matching_matrix[:, anchor_matching_gt > 1] *= 0
- matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
- fg_mask_inboxes = matching_matrix.sum(0) > 0
+ multiple_match_mask = anchor_matching_gt > 1
+ _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+ matching_matrix[:, multiple_match_mask] *= 0
+ matching_matrix[cost_argmin, multiple_match_mask] = 1
+ fg_mask_inboxes = anchor_matching_gt > 0
num_fg = fg_mask_inboxes.sum().item()

fg_mask[fg_mask.clone()] = fg_mask_inboxes
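
Editor's note: for readers who want to poke at the new constraint outside the model, here is a self-contained sketch of the same logic; the function name, toy inputs, and shapes are illustrative, not part of the repo.

import torch

def geometry_constraint(gt_bboxes, anchor_centers, strides, center_radius=1.5):
    # gt_bboxes: [n_gt, 4] as (cx, cy, w, h); anchor_centers: [n_anchor, 2];
    # strides: [n_anchor]. Mirrors get_geometry_constraint above: only the
    # distance from anchor center to GT center matters, never the GT size,
    # so even a 1-pixel object keeps candidate anchors.
    center_dist = strides.unsqueeze(0) * center_radius           # [1, n_anchor]
    dx = (anchor_centers[None, :, 0] - gt_bboxes[:, 0:1]).abs()  # [n_gt, n_anchor]
    dy = (anchor_centers[None, :, 1] - gt_bboxes[:, 1:2]).abs()
    is_in_centers = (dx < center_dist) & (dy < center_dist)
    anchor_filter = is_in_centers.sum(dim=0) > 0          # anchors near any GT
    geometry_relation = is_in_centers[:, anchor_filter]   # [n_gt, n_candidates]
    return anchor_filter, geometry_relation

# Toy usage: 4x4 grid of stride-8 anchors, one tiny 2x2 ground truth.
ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
anchor_centers = (torch.stack([xs, ys], -1).reshape(-1, 2).float() + 0.5) * 8
strides = torch.full((16,), 8.0)
gt = torch.tensor([[9.0, 9.0, 2.0, 2.0]])
fg, rel = geometry_constraint(gt, anchor_centers, strides)
print(fg.sum().item(), rel.shape)  # 9 candidate anchors, torch.Size([1, 9])
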

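Likewise, a compact sketch of the dynamic-k assignment that simota_matching performs, stripped of the class and foreground-mask bookkeeping (names and random inputs again illustrative):

import torch

def simota_match(cost, ious):
    # cost, ious: [n_gt, n_candidate]. Each GT claims its k cheapest candidates,
    # where k = clamp(sum of its top-10 IoUs, min 1); an anchor claimed by
    # several GTs is kept only for the GT with the lowest cost.
    matching = torch.zeros_like(cost, dtype=torch.uint8)
    topk_ious, _ = torch.topk(ious, min(10, ious.size(1)), dim=1)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
    for gt_idx in range(cost.size(0)):
        _, pos_idx = torch.topk(cost[gt_idx], k=int(dynamic_ks[gt_idx]), largest=False)
        matching[gt_idx][pos_idx] = 1
    # Resolve anchors claimed by several GTs in favor of the cheapest one.
    anchor_matching_gt = matching.sum(0)
    if anchor_matching_gt.max() > 1:
        multiple = anchor_matching_gt > 1
        _, cost_argmin = torch.min(cost[:, multiple], dim=0)
        matching[:, multiple] *= 0
        matching[cost_argmin, multiple] = 1
    return matching  # at most one GT per matched anchor

cost = torch.rand(3, 20)
ious = torch.rand(3, 20)
print(simota_match(cost, ious).sum(1))  # anchors assigned to each GT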