fix kldiv when stop grad is true (PaddlePaddle#5643)
littletomatodonkey authored Mar 7, 2022
1 parent db60893 commit 5b33340
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions ppocr/losses/basic_loss.py
@@ -95,9 +95,15 @@ def __init__(self, act=None, use_log=False):
             self.act = None
 
         self.use_log = use_log
-
         self.jskl_loss = KLJSLoss(mode="js")
 
+    def _kldiv(self, x, target):
+        eps = 1.0e-10
+        loss = target * (paddle.log(target + eps) - x)
+        # batch mean loss
+        loss = paddle.sum(loss) / loss.shape[0]
+        return loss
+
     def forward(self, out1, out2):
         if self.act is not None:
             out1 = self.act(out1)
@@ -106,9 +112,8 @@ def forward(self, out1, out2):
             # for recognition distillation, log is needed for feature map
             log_out1 = paddle.log(out1)
             log_out2 = paddle.log(out2)
-            loss = (F.kl_div(
-                log_out1, out2, reduction='batchmean') + F.kl_div(
-                log_out2, out1, reduction='batchmean')) / 2.0
+            loss = (
+                self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
         else:
             # for detection distillation log is not needed
             loss = self.jskl_loss(out1, out2)
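For context on the change above: the commit swaps F.kl_div(..., reduction='batchmean') for a hand-rolled batch-mean KL divergence, which (per the commit title) works around a problem in the framework op when its inputs have stop_gradient set. Below is a minimal sketch of the numerical equivalence, assuming a recent Paddle install; the tensor shapes and the softmax normalization are illustrative assumptions, not taken from the commit.

    # Minimal sketch, assuming a recent Paddle install. Shapes and the softmax
    # normalization are illustrative assumptions, not part of the commit.
    import paddle
    import paddle.nn.functional as F

    def _kldiv(x, target):
        # x: log-probabilities; target: probabilities (same shape)
        eps = 1.0e-10
        loss = target * (paddle.log(target + eps) - x)
        # batch-mean reduction: sum over all elements, divide by batch size
        return paddle.sum(loss) / loss.shape[0]

    p = F.softmax(paddle.randn([4, 25, 100]), axis=-1)  # "teacher" probabilities
    q = F.softmax(paddle.randn([4, 25, 100]), axis=-1)  # "student" probabilities

    manual = _kldiv(paddle.log(q), p)
    builtin = F.kl_div(paddle.log(q), p, reduction='batchmean')
    print(float(manual), float(builtin))  # the two values should agree closely

Computing the reduction by hand keeps the distillation loss independent of F.kl_div's behavior on stop_gradient inputs, and the eps inside the log guards against log(0) when the target assigns zero probability.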