Skip to content

Commit

Permalink
Fix eval function in segmentation demo of ACT (PaddlePaddle#1218)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang authored Jul 1, 2022
1 parent dbdaa38 commit 3a026b6
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 4 deletions.
10 changes: 10 additions & 0 deletions demo/auto_compression/semantic_segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@
- PP-HumanSeg-Lite数据集

- 数据集:AISegment + PP-HumanSeg14K + 内部自建数据集。其中 AISegment 是开源数据集,可从[链接](https://github.com/aisegmentcn/matting_human_datasets)处获取;PP-HumanSeg14K 是 PaddleSeg 自建数据集,可从[官方渠道](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/contrib/PP-HumanSeg/paper.md#pp-humanseg14k-a-large-scale-teleconferencing-video-dataset)获取;内部数据集不对外公开。
- 示例数据集: 用于快速跑通人像分割的压缩和推理流程, 不能用该数据集复现 benckmark 表中的压缩效果。 [下载链接](https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip)

- PP-Liteseg,HRNet,UNet,Deeplabv3-ResNet50数据集

- cityscapes: 请从[cityscapes官网](https://www.cityscapes-dataset.com/login/)下载完整数据
- 示例数据集: cityscapes数据集的一个子集,用于快速跑通压缩和推理流程,不能用该数据集复现 benchmark 表中的压缩效果。[下载链接](https://bj.bcebos.com/v1/paddle-slim-models/data/mini_cityscapes/mini_cityscapes.tar)

下面将以开源数据集为例介绍如何对PP-HumanSeg-Lite进行自动压缩。

Expand Down Expand Up @@ -85,6 +87,14 @@ pip install paddleseg

开发者可下载开源数据集 (如[AISegment](https://github.com/aisegmentcn/matting_human_datasets)) 或自定义语义分割数据集。请参考[PaddleSeg数据准备文档](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/data/marker/marker_cn.md)来检查、对齐数据格式即可。

可以通过以下命令下载人像分割示例数据:

```shell
cd ./data
python download_data.py mini_humanseg

```

#### 3.3 准备预测模型

预测模型的格式为:`model.pdmodel``model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。
Expand Down
10 changes: 6 additions & 4 deletions demo/auto_compression/semantic_segmentation/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
ori_shape,
eval_dataset.transforms.transforms,
mode='bilinear')

pred = paddle.argmax(
paddle.to_tensor(logit), axis=1, keepdim=True, dtype='int32')
pred = paddle.to_tensor(logit)
if len(
pred.shape
) == 4: # for humanseg model whose prediction is distribution but not class id
pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32')

intersect_area, pred_area, label_area = metrics.calculate_area(
pred,
Expand Down Expand Up @@ -166,7 +168,7 @@ def gen():
if __name__ == '__main__':

args = parse_args()

paddle.enable_static()
# step1: load dataset config and create dataloader
data_cfg = PaddleSegDataConfig(args.dataset_config)
train_dataset = data_cfg.train_dataset
Expand Down
15 changes: 15 additions & 0 deletions paddleslim/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import download
163 changes: 163 additions & 0 deletions paddleslim/utils/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import shutil
import sys
import tarfile
import time
import zipfile

import requests

lasttime = time.time()
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()


def _download_file(url, savepath, print_progress):
if print_progress:
print("Connecting to {}".format(url))
r = requests.get(url, stream=True, timeout=15)
total_length = r.headers.get('content-length')

if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)


def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath


def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath


def _uncompress_file(filepath, extrapath, delete_file, print_progress):
if print_progress:
print("Uncompress %s" % os.path.basename(filepath))

if filepath.endswith("zip"):
handler = _uncompress_file_zip
elif filepath.endswith("tgz"):
handler = functools.partial(_uncompress_file_tar, mode="r:*")
else:
handler = functools.partial(_uncompress_file_tar, mode="r")

for total_num, index, rootpath in handler(filepath, extrapath):
if print_progress:
done = int(50 * float(index) / total_num)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * index) / total_num))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

if delete_file:
os.remove(filepath)

return rootpath


def download_file_and_uncompress(url,
savepath=None,
extrapath=None,
extraname=None,
print_progress=True,
cover=False,
delete_file=True):
if savepath is None:
savepath = "."

if extrapath is None:
extrapath = "."

savename = url.split("/")[-1]
if not os.path.exists(savepath):
os.makedirs(savepath)

savepath = os.path.join(savepath, savename)
savename = ".".join(savename.split(".")[:-1])
savename = os.path.join(extrapath, savename)
extraname = savename if extraname is None else os.path.join(extrapath,
extraname)

if cover:
if os.path.exists(savepath):
shutil.rmtree(savepath)
if os.path.exists(savename):
shutil.rmtree(savename)
if os.path.exists(extraname):
shutil.rmtree(extraname)

if not os.path.exists(extraname):
if not os.path.exists(savename):
if not os.path.exists(savepath):
_download_file(url, savepath, print_progress)

if (not tarfile.is_tarfile(savepath)) and (
not zipfile.is_zipfile(savepath)):
if not os.path.exists(extraname):
os.makedirs(extraname)
shutil.move(savepath, extraname)
return extraname

savename = _uncompress_file(savepath, extrapath, delete_file,
print_progress)
savename = os.path.join(extrapath, savename)
shutil.move(savename, extraname)
return extraname

0 comments on commit 3a026b6

Please sign in to comment.