【About Early Stopping (EarlyStopping)】Things You Don't Know
【About Label Smoothing (LabelSmoothing)】Things You Don't Know
km1994 committed Jun 2, 2021
1 parent 4ccc00a commit cfaffab
Showing 22 changed files with 349 additions and 4 deletions.
2 changes: 2 additions & 0 deletions NLPinterview/KG/KBQA/readme.md
@@ -4,6 +4,8 @@
>
> Project: https://github.com/km1994/nlp_paper_study
>
> Interview notes: https://github.com/km1994/NLP-Interview-Notes
>
> About me: Hi everyone, I'm Yang Xi. This project collects what I saw, thought, and learned while reading top-conference papers and reproducing classic ones; it may contain some misunderstandings, and corrections are very welcome.
![](img/微信截图_20210204081440.png)
4 changes: 4 additions & 0 deletions NLPinterview/PreTraining/bert/readme.md
@@ -2,6 +2,10 @@

> Author: Yang Xi
>
> NLP paper reading notes: https://github.com/km1994/nlp_paper_study
>
> Interview notes: https://github.com/km1994/NLP-Interview-Notes/blob/main/NLPinterview/PreTraining/bert/readme.md
>
> Paper: https://arxiv.org/pdf/1810.04805.pdf
>
> Code: https://github.com/google-research/bert
17 changes: 17 additions & 0 deletions README.md
@@ -51,6 +51,8 @@
- [5.1 【About Few-Shot Problems】Things You Don't Know](#51-关于-少样本问题那些你不知道的事)
- [5.2 【About Dirty Data】Things You Don't Know](#52-关于-脏数据那些你不知道的事)
- [5.3 【About the "Alchemy Furnace"】Things You Don't Know](#53-关于-炼丹炉那些你不知道的事)
- [5.4 【About Early Stopping (EarlyStopping)】Things You Don't Know](#54-关于-早停法-earlystopping-那些你不知道的事)
- [5.5 【About Label Smoothing (LabelSmoothing)】Things You Don't Know](#55-关于-标签平滑法-labelsmoothing-那些你不知道的事)
- [6. 【About Python】Things You Don't Know](#六关于-python-那些你不知道的事)
- [7. 【About Tensorflow】Things You Don't Know](#七关于-tensorflow-那些你不知道的事)

@@ -1222,6 +1224,21 @@
- [【About batch_size settings】Things You Don't Know](Trick/batch_size/)
  - [1. How should batch_size and the learning rate be set when training a model?](Trick/batch_size/readme.md#一训练模型时batch_size的设置学习率的设置)

#### 5.4 [【About Early Stopping (EarlyStopping)】Things You Don't Know](Trick/EarlyStopping/)

- [【About Early Stopping (EarlyStopping)】Things You Don't Know](Trick/EarlyStopping/)
  - [1. Why use early stopping (EarlyStopping)?](#一-为什么要用-早停法-earlystopping)
  - [2. What is early stopping (EarlyStopping)?](#二-早停法-earlystopping-是什么)
  - [3. How to implement early stopping in PyTorch?](#三早停法-torch-版本怎么实现)

#### 5.5 [【About Label Smoothing (LabelSmoothing)】Things You Don't Know](Trick/LabelSmoothing/)

- [【About Label Smoothing (LabelSmoothing)】Things You Don't Know](Trick/LabelSmoothing/)
  - [1. Why do we need label smoothing (LabelSmoothing)?](#一为什么要有-标签平滑法-labelsmoothing)
  - [2. What is label smoothing?](#二-标签平滑法-是什么)
  - [3. How to reproduce label smoothing in PyTorch?](#三-标签平滑法-torch-怎么复现)


### 6. [【About Python】Things You Don't Know](python/)

- [【About Python】Things You Don't Know](python/)
Binary file added Trick/EarlyStopping/img/20210523220743.png
98 changes: 98 additions & 0 deletions Trick/EarlyStopping/readme.md
@@ -0,0 +1,98 @@
# 【About Early Stopping (EarlyStopping)】Things You Don't Know

> Author: Yang Xi
>
> Paper-study project: https://github.com/km1994/nlp_paper_study
>
> 《NLP 百面百搭》(NLP interview Q&A): https://github.com/km1994/NLP-Interview-Notes
>
> About me: Hi everyone, I'm Yang Xi. This project collects what I saw, thought, and learned while reading top-conference papers and reproducing classic ones; it may contain some misunderstandings, and corrections are very welcome.
>
![](img/微信截图_20210301212242.png)

> NLP && recommendation-systems study group [group is full; add WeChat blqkm601]
![](img/20210523220743.png)

- [【About Early Stopping (EarlyStopping)】Things You Don't Know](#关于-早停法-earlystopping-那些你不知道的事)
  - [1. Why use early stopping (EarlyStopping)?](#一-为什么要用-早停法-earlystopping)
  - [2. What is early stopping (EarlyStopping)?](#二-早停法-earlystopping-是什么)
  - [3. How to implement early stopping in PyTorch?](#三早停法-torch-版本怎么实现)

## 1. Why use early stopping (EarlyStopping)?

During training, both the training loss and the validation loss trend downward at first. Past a certain point, however, the validation loss stops falling along with the training loss and begins to rise. If training continues beyond this point, model performance on unseen data degrades: this is the overfitting problem.

So how can we avoid this problem?

That is exactly what this section introduces: early stopping.

## 2. What is early stopping (EarlyStopping)?

Early stopping monitors the model's performance on a validation set during training and halts training as soon as that performance starts to deteriorate, avoiding the overfitting that further training would cause. Combined with cross-validation, early stopping is thus an effective guard against overfitting.

## 3. How to implement early stopping in PyTorch?

```python
import torch
import numpy as np

# Early stopping
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after the last validation-loss improvement.
                            Default: 7
            verbose (bool): If True, prints a message for each validation-loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, model_path):
        """
        Run one early-stopping check.

        Args:
            val_loss: validation loss for the current epoch
            model: the model being trained
            model_path: path to save the model checkpoint to
        """
        score = -val_loss

        if self.best_score is None:
            # First call: record the score and save an initial checkpoint
            self.best_score = score
            self.save_checkpoint(val_loss, model, model_path)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: increase the patience counter
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save a checkpoint and reset the counter
            self.best_score = score
            self.save_checkpoint(val_loss, model, model_path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, model_path):
        """Save the model when the validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # Saves the whole model object; saving only the weights is generally
        # preferred: torch.save(model.state_dict(), model_path)
        torch.save(model, model_path)
        self.val_loss_min = val_loss
```
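
To make the class above concrete, here is a minimal, self-contained training-loop sketch. The toy linear model, random data, and the `checkpoint.pt` path are illustrative placeholders only, not part of the original notes:

```python
import torch
import torch.nn as nn

# Toy setup so the sketch runs end to end; swap in your own model and data.
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
x_train, y_train = torch.randn(64, 10), torch.randint(0, 2, (64,))
x_val, y_val = torch.randn(32, 10), torch.randint(0, 2, (32,))

early_stopping = EarlyStopping(patience=7, verbose=True)

for epoch in range(100):
    # One training step per "epoch" for brevity
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    # Evaluate on the validation set
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val).item()

    # Hand the validation loss to the early stopper
    early_stopping(val_loss, model, "checkpoint.pt")
    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch}")
        break
```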
Binary file added Trick/LabelSmoothing/img/20210523220743.png
99 changes: 99 additions & 0 deletions Trick/LabelSmoothing/readme.md
@@ -0,0 +1,99 @@
# 【About Label Smoothing (LabelSmoothing)】Things You Don't Know

> Author: Yang Xi
>
> Paper-study project: https://github.com/km1994/nlp_paper_study
>
> 《NLP 百面百搭》(NLP interview Q&A): https://github.com/km1994/NLP-Interview-Notes
>
> About me: Hi everyone, I'm Yang Xi. This project collects what I saw, thought, and learned while reading top-conference papers and reproducing classic ones; it may contain some misunderstandings, and corrections are very welcome.
>
![](img/微信截图_20210301212242.png)

> NLP && recommendation-systems study group [group is full; add WeChat blqkm601]
![](img/20210523220743.png)

- [【About Label Smoothing (LabelSmoothing)】Things You Don't Know](#关于-标签平滑法-labelsmoothing-那些你不知道的事)
  - [1. Why do we need label smoothing (LabelSmoothing)?](#一为什么要有-标签平滑法-labelsmoothing)
  - [2. What is label smoothing?](#二-标签平滑法-是什么)
  - [3. How to reproduce label smoothing in PyTorch?](#三-标签平滑法-torch-怎么复现)
  - [References](#参考)

## 1. Why do we need label smoothing (LabelSmoothing)?

- The problem with the cross-entropy loss in multi-class classification

In a multi-class task, the network outputs a confidence score for each class for the current input; these scores are normalized with softmax, giving the probability that the input belongs to each class.

The cross-entropy loss is then computed:

![](img/微信截图_20210602203923.png)

Training the network minimizes the cross-entropy between the predicted probabilities and the true label distribution, pushing the predictions toward the optimal distribution, which is:

![](img/微信截图_20210602204003.png)

**The network is therefore driven to maximize the gap between the correct label and the incorrect labels. When the training data are scarce and do not represent the full range of sample characteristics, this causes the network to overfit.**
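
In case the formula screenshots above do not display, here is a sketch of the standard formulation they correspond to (a reconstruction from the cited references, not the original images; $z_k$ denotes the logits and $K$ the number of classes):

```latex
% Softmax probabilities and cross-entropy against a one-hot target q(k).
p(k) = \frac{\exp(z_k)}{\sum_{i=1}^{K} \exp(z_i)}, \qquad
L = -\sum_{k=1}^{K} q(k)\,\log p(k), \qquad
q(k) =
\begin{cases}
  1, & k = y \\
  0, & k \neq y
\end{cases}
% The loss is minimized as p(y) -> 1 and p(k != y) -> 0, i.e. when the
% logit of the true class is pushed arbitrarily far above all the others.
```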

## 2. What is label smoothing?

Label smoothing addresses this problem. It is a regularization strategy that injects noise via soft one-hot targets, reducing the weight of the true-label class when computing the loss and ultimately suppressing overfitting.

![](img/微信截图_20210602204205.png)

With label smoothing, the true probability distribution changes as follows:

![](img/微信截图_20210602204441.png)

The cross-entropy loss changes as follows:

![](img/微信截图_20210602204518.png)

And the optimal predicted distribution becomes:

![](img/微信截图_20210602204551.png)
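
Likewise, in case these screenshots do not render, here is a sketch of the smoothed formulation, written with the $\epsilon/(K-1)$ variant that the PyTorch code in the next section implements:

```latex
% Smoothed target distribution with smoothing parameter epsilon over K classes.
q'(k) =
\begin{cases}
  1 - \epsilon,            & k = y \\
  \dfrac{\epsilon}{K - 1}, & k \neq y
\end{cases}
% Cross-entropy against the smoothed target, where L = -log p(y) is the
% original one-hot cross-entropy:
L' = -\sum_{k=1}^{K} q'(k)\,\log p(k)
   = (1 - \epsilon)\,L \;-\; \frac{\epsilon}{K - 1} \sum_{k \neq y} \log p(k)
% L' is minimized when p(k) = q'(k), so the optimal prediction is no longer
% one-hot and the gap between correct and incorrect logits stays finite.
```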

## 3. How to reproduce label smoothing in PyTorch?

```python
import torch
import torch.nn as nn

# Label smoothing
class LabelSmoothing(nn.Module):
    """Label-smoothed loss built on nn.KLDivLoss.

    nn.KLDivLoss computes the Kullback-Leibler divergence between the input
    (log-probabilities) and the target distribution.
    """

    def __init__(self, size, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # reduction='sum' replaces the deprecated size_average=False
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.confidence = 1.0 - smoothing  # target value for the k = y case
        self.smoothing = smoothing
        self.size = size                   # number of classes K
        self.true_dist = None

    def forward(self, x, target):
        """
        Args:
            x: (N, K) log-probabilities (log P) for N samples over K classes
            target: (N,) ground-truth class indices
        Returns:
            the label-smoothed loss
        """
        assert x.size(1) == self.size
        true_dist = x.data.clone()                         # same shape/device as x
        true_dist.fill_(self.smoothing / (self.size - 1))  # the k != y case
        # Turn each row into a smoothed one-hot distribution: dim=1 scatters
        # along the class axis, target.unsqueeze(1) gives the indices, and
        # confidence is the value written at the true-class positions.
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        self.true_dist = true_dist
        # detach() replaces the deprecated Variable(..., requires_grad=False)
        return self.criterion(x, true_dist.detach())
```
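
A quick runnable check of the module above; the sizes, logits, and targets are illustrative values:

```python
import torch
import torch.nn.functional as F

# 3 samples, 5 classes, smoothing = 0.1 (illustrative values).
criterion = LabelSmoothing(size=5, smoothing=0.1)
logits = torch.randn(3, 5)
log_probs = F.log_softmax(logits, dim=1)  # the module expects log-probabilities
target = torch.tensor([2, 0, 4])
loss = criterion(log_probs, target)
print(loss.item())

# Each row of the smoothed target puts 1 - 0.1 = 0.9 on the true class and
# 0.1 / (5 - 1) = 0.025 on every other class:
print(criterion.true_dist[0])  # e.g. tensor([0.0250, 0.0250, 0.9000, 0.0250, 0.0250])
```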

## References

1. [Label smoothing (标签平滑) study notes](https://zhuanlan.zhihu.com/p/116466239)
2. [Label smoothing & deep learning: Google Brain explains why label smoothing helps and when to use it](https://zhuanlan.zhihu.com/p/101553787)