Skip to content

Commit

Permalink
update loss
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Apr 3, 2022
1 parent f11823b commit fd8ee24
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions nets/yolo_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ def forward(self, l, input, targets=None):
#-----------------------------------------------------------#
# 计算中心偏移情况的loss,使用BCELoss效果好一些
#-----------------------------------------------------------#
loss_x = torch.mean(self.BCELoss(x[obj_mask], y_true[..., 0][obj_mask]) * box_loss_scale)
loss_y = torch.mean(self.BCELoss(y[obj_mask], y_true[..., 1][obj_mask]) * box_loss_scale)
loss_x = torch.mean(self.BCELoss(x[obj_mask], y_true[..., 0][obj_mask]) * box_loss_scale[obj_mask])
loss_y = torch.mean(self.BCELoss(y[obj_mask], y_true[..., 1][obj_mask]) * box_loss_scale[obj_mask])
#-----------------------------------------------------------#
# 计算宽高调整值的loss
#-----------------------------------------------------------#
loss_w = torch.mean(self.MSELoss(w[obj_mask], y_true[..., 2][obj_mask]) * box_loss_scale)
loss_h = torch.mean(self.MSELoss(h[obj_mask], y_true[..., 3][obj_mask]) * box_loss_scale)
loss_w = torch.mean(self.MSELoss(w[obj_mask], y_true[..., 2][obj_mask]) * box_loss_scale[obj_mask])
loss_h = torch.mean(self.MSELoss(h[obj_mask], y_true[..., 3][obj_mask]) * box_loss_scale[obj_mask])
loss_loc = (loss_x + loss_y + loss_h + loss_w) * 0.1

loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
Expand Down

0 comments on commit fd8ee24

Please sign in to comment.