Skip to content

Commit

Permalink
update rfc of pidnet (PaddlePaddle#722)
Browse files Browse the repository at this point in the history
  • Loading branch information
flytocc authored Nov 6, 2023
1 parent d49112a commit 6126899
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions rfcs/PaddleSeg/pidnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

| 任务名称 | 轻量语义分割网络PIDNet |
|----------------------------------------------------------|----------------------|
| 提交作者<input type="checkbox" class="rowselector hidden"> | Asthestarsfalll |
| 提交作者<input type="checkbox" class="rowselector hidden"> | [Asthestarsfalll](https://github.com/Asthestarsfalll) [flytocc](https://github.com//flytocc) |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-9-26 |
| 版本号 | V1.0 |
| 修改时间<input type="checkbox" class="rowselector hidden"> | 2023-10-27 |
| 版本号 | V2.0 |
| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | develop版本 |
| 文件名 | pidnet.md<br> |

Expand All @@ -26,7 +27,9 @@

PIDNet 源码已经开源,地址:https://github.com/XuJiacong/PIDNet

性能表现如下
mmsegmentation的实现:https://github.com/open-mmlab/mmsegmentation

## 性能表现如下

| Model (Cityscapes) | Val (% mIOU) | Test (% mIOU)| FPS |
|:-:|:-:|:-:|:-:|
Expand All @@ -41,11 +44,16 @@ PIDNet 源码已经开源,地址:https://github.com/XuJiacong/PIDNet

# 四、对比分析

参考官方原码实现即可。
精度:官方版本和mmsegmentation版本基本相同。

数据增强:mmsegmentation较官方版本,增加了`PhotoMetricDistortion`,并在`RandomCrop`中设置了`category_max_ratio=0.75`

# 五、设计思路与实现方案

## 总体思路

以官方版本为主要参考,尽可能在不使用`PhotoMetricDistortion`增强的情况下,复现论文中的精度。

### Edge label 实现

PIDNet 中需要额外生成 edge label 以对边缘部分进行监督,虽然 PaddleSeg 中已经内置了 edge label 的方式,但是具体实现细节差距较大,因此需要单独添加该生成方法。
Expand All @@ -69,34 +77,46 @@ class AddEdgeLabel:
### LOSS 对齐

PIDNet 中使用了多个loss,其中 sem_loss 为 cross_entropy 和 ohem 的组合,与 PaddleSeg 有以下冲突:
1. cross_entropy 的 reduction 参数为 False,即维持输入的形状;
2. ohem 中使用了 class weight。
#### 1. 修复cross_entropy当`weight is not None``avg_non_ignore=False`时的bug;

考虑修改如下:
为 PaddleSeg 的CrossEntropyLoss 添加 use_post_process 的参数用于控制是否平均输出。
将下面这个判断:
```python
if self.use_post_process:
return self._post_process_loss(logit, label, semantic_weights, loss)
return loss
if self.weight is not None:
...
```
修改为:
```python
if self.avg_non_ignore and self.weight is not None:
...
```

为 PaddleSeg 的 OhemCrossEntropyLoss 添加 weight 参数

#### 2. ohem 中使用了 class weight。

将下面的代码:
```python
if self.weight is not None:
loss = F.cross_entropy(
logit, label, weight=self.weight, ignore_index=self.ignore_index, axis=1)
else:
loss = F.softmax_with_cross_entropy(
logit, label, ignore_index=self.ignore_index, axis=1)
```
修改为:
```python
loss = F.cross_entropy(
logit, label, ignore_index=self.ignore_index, axis=1,
weight=self.weight, reduction='none')
```
ps:已验证当`self.weight is None`时,修改前后结果相同。


# 六、测试和验收的考量

达到论文Table.6中的指标,进行TIPC验证lite train lite infer 链条,参考PR提交规范提交代码PR到ppseg中。
达到论文Table.6中的`PIDNet-S`的val精度`78.8`指标,进行TIPC验证lite train lite infer 链条,参考PR提交规范提交代码PR到ppseg中。

# 七、影响面

对其他模块没有影响。
需要评估对cross_entropy的bug修复,是否会影响现在其他模型。

初步评估:现在其他模型不存在`weight is None``avg_non_ignore=False`的情况,所以对现在其他模型无影响。

# 八、排期规划

Expand Down

0 comments on commit 6126899

Please sign in to comment.