Skip to content

Commit

Permalink
[Feature] Add multi class focal loss (PaddlePaddle#1957)
Browse files Browse the repository at this point in the history
  • Loading branch information
juncaipeng authored Apr 5, 2022
1 parent de9c024 commit 03b8297
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions paddleseg/models/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,53 @@ def forward(self, logit, label):
avg_loss = paddle.sum(loss) / (
paddle.sum(paddle.cast(mask != 0., 'int32')) * class_num + self.EPS)
return avg_loss


@manager.LOSSES.add_component
class MultiClassFocalLoss(nn.Layer):
    """
    The implementation of focal loss for multi-class segmentation.

    Loss = alpha * (1 - pt) ** gamma * CE(logit, label), averaged over
    non-ignored pixels.

    Args:
        num_class (int): Number of classes. Currently unused by the loss
            computation; kept for interface compatibility.
        alpha (float, optional): A scalar scaling factor applied to the whole
            focal loss (not per-class weights). Default: 1.0.
        gamma (float, optional): The focusing parameter of focal loss; larger
            values down-weight well-classified examples more. Default: 2.0.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
    """

    def __init__(self, num_class, alpha=1.0, gamma=2.0, ignore_index=255):
        super().__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        # Guards against division by zero when every pixel is ignored.
        self.EPS = 1e-10

    def forward(self, logit, label):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor, the data type is float32, float64. Shape is
                (N, C, H, W), where C is number of classes.
            label (Tensor): Label tensor, the data type is int64. Shape is (N, H, W),
                where each value is 0 <= label[i] <= C-1 or equals ignore_index.

        Returns:
            (Tensor): The average loss over the non-ignored pixels.
        """
        assert logit.ndim == 4, "The ndim of logit should be 4."
        assert label.ndim == 3, "The ndim of label should be 3."

        # cross_entropy expects channels last when soft_label is False, so move
        # C from axis 1 to the last axis: (N, C, H, W) -> (N, H, W, C).
        logit = paddle.transpose(logit, [0, 2, 3, 1])
        label = label.astype('int64')
        # Per-pixel CE with ignored pixels contributing 0 gradient.
        ce_loss = F.cross_entropy(
            logit, label, ignore_index=self.ignore_index, reduction='none')

        # pt = softmax probability of the true class, recovered from CE.
        pt = paddle.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt)**self.gamma) * ce_loss

        # Zero out ignored pixels, then normalize by the fraction of valid
        # pixels: mean(loss) / mean(mask) == sum(loss) / count(valid).
        mask = paddle.cast(label != self.ignore_index, 'float32')
        focal_loss *= mask
        avg_loss = paddle.mean(focal_loss) / (paddle.mean(mask) + self.EPS)
        return avg_loss

0 comments on commit 03b8297

Please sign in to comment.