Skip to content

Commit

Permalink
添加新算法文档添加loss 的module样例
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jul 2, 2020
1 parent 52b642b commit 263ca69
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
19 changes: 19 additions & 0 deletions doc/添加新算法.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ PytorchOCR将网络划分为三部分
'其他的loss': value # 组成总loss的子loss
}
```
loss module 的形式如下
```python
class ModuleName(nn.Module):
def __init__(self, *args,**kwargs):
pass

def forward(self, pred, batch):
"""
:param pred:
:param batch: bach为一个dict{
'其他计算loss所需的输入':'vaue'
}
:return:
"""
# 计算loss
loss_dict = {'loss':loss,'other_sub_loss':value}
return loss_dict
```

### 配置文件

Expand Down
10 changes: 5 additions & 5 deletions torchocr/networks/losses/DBLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def forward(self, pred, batch):

loss_shrink_maps = self.bce_loss(shrink_maps, batch['shrink_map'], batch['shrink_mask'])
loss_threshold_maps = self.l1_loss(threshold_maps, batch['threshold_map'], batch['threshold_mask'])
metrics = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
loss_dict = dict(loss_shrink_maps=loss_shrink_maps, loss_threshold_maps=loss_threshold_maps)
if pred.size()[1] > 2:
loss_binary_maps = self.dice_loss(binary_maps, batch['shrink_map'], batch['shrink_mask'])
metrics['loss_binary_maps'] = loss_binary_maps
loss_dict['loss_binary_maps'] = loss_binary_maps
loss_all = self.alpha * loss_shrink_maps + self.beta * loss_threshold_maps + loss_binary_maps
metrics['loss'] = loss_all
loss_dict['loss'] = loss_all
else:
metrics['loss'] = loss_shrink_maps
loss_dict['loss'] = loss_shrink_maps

return metrics
return loss_dict

0 comments on commit 263ca69

Please sign in to comment.