Skip to content

Commit

Permalink
fix inconsistent in loss function and training
Browse files Browse the repository at this point in the history
  • Loading branch information
coderbyr committed Aug 14, 2019
1 parent 7a1b216 commit e2786c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def forward(self, logits, target,
else:
if is_multi:
assert self.loss_type in [LossType.BCE_WITH_LOGITS,
LossType.SIGMOID_FOCAL_CROSS_ENTROPY]
LossType.SIGMOID_FOCAL_CROSS_ENTROPY]
else:
if self.loss_type not in [LossType.SOFTMAX_CROSS_ENTROPY,
LossType.SOFTMAX_FOCAL_CROSS_ENTROPY]
LossType.SOFTMAX_FOCAL_CROSS_ENTROPY]:
target = torch.eye(self.label_size)[target].to(device)
return self.criterion(logits, target)

Expand Down

0 comments on commit e2786c5

Please sign in to comment.