[distill]add distillation losses (PaddlePaddle#789)
littletomatodonkey authored Jul 26, 2021
1 parent c9c0e83 commit dfe5d3f
Showing 4 changed files with 1,110 additions and 0 deletions.
70 changes: 70 additions & 0 deletions paddleslim/dygraph/dist/losses/__init__.py
@@ -11,3 +11,73 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import paddle
import paddle.nn as nn

from . import basic_loss
from . import distillation_loss

from .basic_loss import L1Loss
from .basic_loss import L2Loss
from .basic_loss import SmoothL1Loss
from .basic_loss import CELoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import RKdAngle, RkdDistance

from .distillation_loss import DistillationDistanceLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationRKDLoss


class CombinedLoss(nn.Layer):
"""
CombinedLoss: a combination of loss function.
Args:
loss_config_list: a config list used to build loss function. A demo is as follows,
which is used to calculate dml loss between Student output and
Teacher output. Parameter weight is needed for the loss weight.
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
"""

def __init__(self, loss_config_list=None):
super().__init__()
loss_config_list = copy.deepcopy(loss_config_list)
self.loss_func = []
self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'loss_config_list should be a list')
        supported_loss_list = basic_loss.__all__ + distillation_loss.__all__
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            assert name in supported_loss_list, \
                "loss name must be in {} but got: {}".format(
                    supported_loss_list, name)
            param = config[name]
            assert "weight" in param, "weight must be contained in the loss config, but it only contains: {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))

def forward(self, input, batch, **kargs):
loss_dict = {}
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
weight = self.loss_weight[idx]
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss * weight}
else:
loss = {
"{}_{}".format(key, idx): loss[key] * weight
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict
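# Usage sketch (illustrative only; the exact forward signature of the
# distillation losses is defined in distillation_loss.py, so the dict-of-outputs
# format below is an assumption based on the docstring demo above):
#
#     loss_config_list = [{
#         "DistillationDMLLoss": {
#             "weight": 1.0,
#             "act": "softmax",
#             "model_name_pairs": [["Student", "Teacher"]],
#         }
#     }]
#     combined_loss = CombinedLoss(loss_config_list)
#     predicts = {
#         "Student": paddle.rand([8, 10]),
#         "Teacher": paddle.rand([8, 10]),
#     }
#     loss_dict = combined_loss(predicts, batch=None)
#     # loss_dict holds one weighted entry per sub-loss plus their sum under "loss".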
207 changes: 207 additions & 0 deletions paddleslim/dygraph/dist/losses/basic_loss.py
@@ -0,0 +1,207 @@
#Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss

__all__ = [
"CELoss",
"DMLLoss",
"DistanceLoss",
"RKdAngle",
"RkdDistance",
]


class CELoss(nn.Layer):
"""
    CELoss: cross entropy loss
    Args:
        epsilon(float | None): label smoothing epsilon. If it is None or not in the
            range (0, 1), label smoothing is not used.
        label_act(string | None): activation applied to the label when the label is
            itself a set of logits (e.g. teacher output) rather than a ground-truth label.
        axis(int): axis along which the cross entropy loss is calculated.
"""

def __init__(self, epsilon=None, label_act="softmax", axis=-1):
super().__init__()
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        assert label_act in ["softmax", None]
self.epsilon = epsilon
self.label_act = label_act
self.axis = axis

def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target

def forward(self, x, label):
        assert len(x.shape) == len(label.shape), \
            "x and label must have the same number of dimensions, but got x.shape={} and label.shape={}".format(x.shape, label.shape)
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=self.axis)
loss = paddle.sum(x * label, axis=self.axis)
else:
if label.shape[self.axis] == x.shape[self.axis]:
if self.label_act == "softmax":
label = F.softmax(label, axis=self.axis)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(
x, label=label, soft_label=soft_label, axis=self.axis)
loss = loss.mean()
return loss
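# Usage sketch (illustrative; the tensor shapes below are assumptions). With a
# hard label, the label tensor must have the same number of dimensions as the
# logits; with a soft label (e.g. teacher logits), both tensors share the same
# shape and label_act="softmax" normalizes the label first.
#
#     logits = paddle.rand([8, 10])
#     hard_label = paddle.randint(0, 10, [8, 1])
#     loss = CELoss(epsilon=0.1)(logits, hard_label)               # label-smoothed CE
#
#     teacher_logits = paddle.rand([8, 10])
#     loss = CELoss(label_act="softmax")(logits, teacher_logits)   # soft-label CE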


class DMLLoss(nn.Layer):
"""
DMLLoss
Args:
        act(string | None): activation function applied to both inputs before the loss is computed
        axis(int): axis along which the activation is applied
"""

def __init__(self, act=None, axis=-1):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=axis)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None

def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)

log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
return loss
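# Usage sketch (illustrative; shapes are assumptions): DMLLoss computes a
# symmetric KL divergence between two distributions, typically the softmax
# outputs of a student and a teacher.
#
#     dml = DMLLoss(act="softmax")
#     student_logits = paddle.rand([8, 10])
#     teacher_logits = paddle.rand([8, 10])
#     loss = dml(student_logits, teacher_logits)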


class DistanceLoss(nn.Layer):
"""
DistanceLoss
Args:
        mode(string): distance loss mode, one of "l1", "l2" or "smooth_l1"
        kargs(dict): used to build the corresponding loss function. For more details, please
            refer to:
                L1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/L1Loss_cn.html#l1loss
                L2Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/MSELoss_cn.html#mseloss
                SmoothL1Loss: https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/SmoothL1Loss_cn.html#smoothl1loss
"""

def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)

def forward(self, x, y):
return self.loss_func(x, y)
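# Usage sketch (illustrative; shapes are assumptions): extra keyword arguments
# are forwarded to the underlying paddle.nn loss, e.g. reduction="sum".
#
#     dist_loss = DistanceLoss(mode="smooth_l1")
#     student_feat = paddle.rand([8, 64])
#     teacher_feat = paddle.rand([8, 64])
#     loss = dist_loss(student_feat, teacher_feat)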


def pdist(e, squared=False, eps=1e-12):
    # pairwise (optionally squared) Euclidean distances between the rows of e
    # ([N, D] -> [N, N]), computed as ||a||^2 + ||b||^2 - 2 * a.b and clipped to
    # eps for numerical stability
    e_square = e.pow(2).sum(axis=1)
    prod = paddle.mm(e, e.t())
    res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clip(
        min=eps)

    # take the square root unless squared distances were requested
    if not squared:
        res = res.sqrt()

    return res


class RKdAngle(nn.Layer):
"""
RKdAngle loss, see https://arxiv.org/abs/1904.05068
"""

def __init__(self):
super().__init__()

def forward(self, student, teacher):
# reshape for feature map distillation
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])

td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, axis=2)
t_angle = paddle.bmm(norm_td, norm_td.transpose([0, 2, 1])).reshape(
[-1, 1])

sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, axis=2)
s_angle = paddle.bmm(norm_sd, norm_sd.transpose([0, 2, 1])).reshape(
[-1, 1])
loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
return loss


class RkdDistance(nn.Layer):
"""
RkdDistance loss, see https://arxiv.org/abs/1904.05068
Args:
eps(float): epsilon for the pdist function
"""

def __init__(self, eps=1e-12):
super().__init__()
self.eps = eps

def forward(self, student, teacher):
bs = student.shape[0]
student = student.reshape([bs, -1])
teacher = teacher.reshape([bs, -1])

t_d = pdist(teacher, squared=False, eps=self.eps)
mean_td = t_d.mean()
t_d = t_d / (mean_td + self.eps)

d = pdist(student, squared=False, eps=self.eps)
mean_d = d.mean()
d = d / (mean_d + self.eps)

loss = F.smooth_l1_loss(d, t_d, reduction="mean")
return loss
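# Usage sketch for the relational KD losses (illustrative; shapes are
# assumptions): both losses compare relations between the samples in a batch,
# so feature maps of any shape are accepted and flattened to [batch_size, -1],
# and the student and teacher feature dimensions do not need to match.
#
#     student_feat = paddle.rand([8, 32, 7, 7])
#     teacher_feat = paddle.rand([8, 64, 7, 7])
#     angle_loss = RKdAngle()(student_feat, teacher_feat)
#     dist_loss = RkdDistance()(student_feat, teacher_feat)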