# Self_cross_entropy

Write a cross_entropy function in PyTorch to remove abnormal nan values.


**Problem:** When using PyTorch's cross-entropy as the loss function, the loss became nan after a few iterations. After a long debugging session it turned out that the attention module sometimes received a mask that was all zeros, which made all of the attention values nan. So I wrote a cross_entropy that drops these individual abnormal samples from the loss.
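The exact implementation is in the repository's code; the following is only a minimal sketch of the idea, assuming the strategy is to drop any sample whose logits contain nan/inf before the loss is computed (the function name and signature here are placeholders, not necessarily what the repo uses):

```python
import torch
import torch.nn.functional as F

def self_cross_entropy(pred, target):
    """Cross-entropy that skips samples whose logits are nan/inf.

    pred:   (N, C) raw logits
    target: (N,)   class indices
    """
    # Mark the samples whose logits are all finite; abnormal samples
    # (e.g. produced by an attention mask that is all zeros) are dropped.
    finite = torch.isfinite(pred).all(dim=1)
    if not finite.any():
        # Degenerate case: every sample is abnormal. Return a zero loss
        # that backward() accepts but that sends no gradient anywhere.
        return pred.new_zeros((), requires_grad=True)
    # Boolean indexing sends zero gradient back to the dropped rows,
    # so their nan values never reach the parameter gradients.
    return F.cross_entropy(pred[finite], target[finite])
```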

See the code for the details. The test output is:

Output in the normal case:

- Custom loss: `tensor(1.6094)`
- Official loss: `tensor(1.6094)`

Output when an individual abnormal sample is present:

- Custom loss: `tensor(7.9369, device='cuda:0', grad_fn=<MeanBackward1>)`
- Official loss: `tensor(nan, device='cuda:0', grad_fn=<NllLossBackward>)`
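A toy comparison in the same spirit, using the `self_cross_entropy` sketch above (the batch shape and values here are made up, so the numbers will differ from the output shown):

```python
import torch
import torch.nn.functional as F

# Toy batch: 8 samples, 5 classes, with one sample's logits corrupted to nan.
logits = torch.randn(8, 5)
logits[0] = float('nan')
target = torch.randint(0, 5, (8,))

print(F.cross_entropy(logits, target))     # nan: one bad sample poisons the mean
print(self_cross_entropy(logits, target))  # finite: the bad sample is skipped
```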

Test data download link: state.pth, password: fscg


## Other reference solutions

  1. Add a very small constant to the predictions, e.g. 1e-10:
     `loss = crossentropy(out + 1e-8, target)`
  2. Use a smaller learning rate.
  3. Clip the gradients: "The recommended thing to do when using ReLUs is to clip the gradient." Reference: here. (A short sketch follows this list.)
  4. The data itself may be the problem, as in this person's case: link.
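As an illustration of the gradient-clipping suggestion above, here is a minimal sketch of a training step; `model`, `optimizer`, `criterion`, and the threshold `max_norm=1.0` are placeholders, not values from this repository:

```python
import torch

def train_step(model, optimizer, criterion, inputs, targets, max_norm=1.0):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Rescale the gradients so their global norm does not exceed max_norm;
    # this keeps a single bad batch from blowing the weights up to nan.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()
```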

[Reference]
