From fd8ee24f688ae0f706413f1721f53e1beb464f40 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <3323290568@qq.com> Date: Sun, 3 Apr 2022 12:00:42 +0800 Subject: [PATCH] update loss --- nets/yolo_training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 7fa4254..be03c65 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -190,13 +190,13 @@ def forward(self, l, input, targets=None): #-----------------------------------------------------------# # 计算中心偏移情况的loss,使用BCELoss效果好一些 #-----------------------------------------------------------# - loss_x = torch.mean(self.BCELoss(x[obj_mask], y_true[..., 0][obj_mask]) * box_loss_scale) - loss_y = torch.mean(self.BCELoss(y[obj_mask], y_true[..., 1][obj_mask]) * box_loss_scale) + loss_x = torch.mean(self.BCELoss(x[obj_mask], y_true[..., 0][obj_mask]) * box_loss_scale[obj_mask]) + loss_y = torch.mean(self.BCELoss(y[obj_mask], y_true[..., 1][obj_mask]) * box_loss_scale[obj_mask]) #-----------------------------------------------------------# # 计算宽高调整值的loss #-----------------------------------------------------------# - loss_w = torch.mean(self.MSELoss(w[obj_mask], y_true[..., 2][obj_mask]) * box_loss_scale) - loss_h = torch.mean(self.MSELoss(h[obj_mask], y_true[..., 3][obj_mask]) * box_loss_scale) + loss_w = torch.mean(self.MSELoss(w[obj_mask], y_true[..., 2][obj_mask]) * box_loss_scale[obj_mask]) + loss_h = torch.mean(self.MSELoss(h[obj_mask], y_true[..., 3][obj_mask]) * box_loss_scale[obj_mask]) loss_loc = (loss_x + loss_y + loss_h + loss_w) * 0.1 loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))