add ppocrv3_rec (PaddlePaddle#1575)
* add ppocrv3_rec

* fix eval.py
zzjjay authored Dec 6, 2022
1 parent 22ec8ea commit ef6a8f2
Showing 5 changed files with 6,769 additions and 56 deletions.
26 changes: 17 additions & 9 deletions example/auto_compression/ocr/README.md
@@ -21,12 +21,14 @@
| Model | Strategy | Metric | GPU latency (ms) | ARM CPU latency (ms) | Config file | Inference model |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| Chinese PPOCRV3-det | Baseline | 84.57 | - | - | - | [Model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar) |
| Chinese PPOCRV3-det | Quantization + distillation | 83.4 | - | - | [Config](./configs/ppocrv3_det_qat_dist.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/PPOCRV3_det_QAT.tar) |

| Chinese PPOCRV3-det | Quantization + distillation | 85.01 | - | - | [Config](./configs/ppocrv3_det_qat_dist.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/OCR/PPOCRV3_det_QAT.tar) |
| Chinese PPOCRV3-rec | Baseline | 76.48 | - | - | - | [Model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar) |
| Chinese PPOCRV3-rec | Quantization + distillation | 73.23 | - | - | [Config](./configs/ppocrv3_rec_qat_dist.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/OCR/PPOCRV3_rec_QAT.tar) |
> The evaluation metric for PPOCRV3-det is hmean; the evaluation metric for PPOCRV3-rec is accuracy.
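
The two metrics named in the note above can be illustrated with a short sketch. It assumes hmean is the harmonic mean of detection precision and recall, and accuracy is the fraction of text lines recognized exactly. The helper names and the numbers are hypothetical; this is not PaddleOCR's metric code.

```python
from typing import List


def hmean(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def accuracy(predictions: List[str], labels: List[str]) -> float:
    """Fraction of predictions that match their label exactly."""
    if not labels:
        return 0.0
    correct = sum(p == t for p, t in zip(predictions, labels))
    return correct / len(labels)


# Made-up values, for illustration only.
print(hmean(0.8, 0.9))
print(accuracy(["abc", "xyz"], ["abc", "xyw"]))
```

Because the harmonic mean is pulled toward the smaller of its two inputs, a detector cannot reach a high hmean by trading recall for precision or the reverse.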
## 3. Automatic compression workflow

#### 3.1 Prepare the environment
### 3.1 Prepare the environment

- python >= 3.6
- PaddlePaddle >= 2.3 (can be installed from the [Paddle website](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html))
@@ -45,15 +47,21 @@ pip install paddlepaddle-gpu
pip install paddleslim
```

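#### 3.2 Prepare the dataset
Download PaddleOCR:
```shell
git clone -b release/2.6 https://github.com/PaddlePaddle/PaddleOCR.git
```
> The only purpose of downloading [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR.git) is to reuse its Dataloader components and accuracy-evaluation module; model construction is not involved. The paddleocr package installed with `pip install paddleocr` contains only inference code, without the dataset-loading and accuracy-evaluation parts, so the PaddleOCR repository itself has to be downloaded.
### 3.2 Prepare the dataset
For public datasets, see [OCR datasets](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/dataset/ocr_datasets.md)

Note: when using a different dataset, update the data paths and data-processing settings under `dataset` in the config file.
> Note: when using a different dataset, update the data paths and data-processing settings under `dataset` in the config file.
#### 3.3 Prepare the inference model
### 3.3 Prepare the inference model
The inference model consists of two files, `model.pdmodel` and `model.pdiparams`: the file with the `pdmodel` suffix is the model file, and the file with the `pdiparams` suffix is the weights file.

Note: other names such as `__model__` and `__params__` correspond to the `model.pdmodel` and `model.pdiparams` files respectively.
> Note: other names such as `__model__` and `__params__` correspond to the `model.pdmodel` and `model.pdiparams` files respectively.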
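
As a rough illustration of the file layout described above, the sketch below looks for a `*.pdmodel` / `*.pdiparams` pair in a directory and falls back to the `__model__` / `__params__` names. `find_model_files` is a hypothetical helper written for this note, not part of Paddle or PaddleSlim, and it uses only the standard library.

```python
from pathlib import Path


def find_model_files(model_dir):
    """Return (model_file, params_file) found in `model_dir`.

    Hypothetical helper: checks for a `<name>.pdmodel` file with a matching
    `<name>.pdiparams`, then for the `__model__` / `__params__` naming.
    """
    root = Path(model_dir)
    for model_file in sorted(root.glob("*.pdmodel")):
        params_file = model_file.with_suffix(".pdiparams")
        if params_file.is_file():
            return model_file, params_file
    legacy_model = root / "__model__"
    legacy_params = root / "__params__"
    if legacy_model.is_file() and legacy_params.is_file():
        return legacy_model, legacy_params
    raise FileNotFoundError("no model/params pair found in {}".format(root))
```

Matching on the suffix rather than on a fixed file name also covers exports named `inference.pdmodel` / `inference.pdiparams`, which is the naming used by the config added in this commit.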
Inference models can be obtained directly from the [PaddleOCR model zoo](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md); see the example below, which fetches the Chinese PPOCRV3 detection model:

@@ -94,13 +102,13 @@ python eval.py --config_path='./configs/ppocrv3_det_qat_dist.yaml'
```

## 4. Inference deployment
#### 4.1 Python inference
### 4.1 Python inference

Environment setup: to use the TensorRT inference engine, install a Paddle build compiled with ```WITH_TRT=ON```. Download: [Python inference library](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python)

For inference with the Python inference engine, see [Inference with the Python inference engine](https://github.com/PaddlePaddle/PaddleOCR/blob/9cdab61d909eb595af849db885c257ca8c74cb57/doc/doc_ch/inference_ppocr.md)

#### 4.2 Paddle Lite on-device deployment
### 4.2 Paddle Lite on-device deployment
For Paddle Lite on-device deployment, see:
- [Paddle Lite deployment](https://github.com/PaddlePaddle/PaddleOCR/tree/9cdab61d909eb595af849db885c257ca8c74cb57/deploy/lite)

109 changes: 109 additions & 0 deletions example/auto_compression/ocr/configs/ppocrv3_rec_qat_dist.yaml
@@ -0,0 +1,109 @@
Global:
  model_dir: ch_PP-OCRv3_rec_infer
  model_filename: inference.pdmodel
  params_filename: inference.pdiparams
  model_type: rec
  algorithm: SVTR
  character_dict_path: ./ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  use_space_char: true

Distillation:
  alpha: [1.0, 1.0]
  loss: ['skd', 'l2']
  node:
  - ['linear_43.tmp_1']
  - ['linear_43.tmp_1']

Quantization:
  use_pact: true
  activation_bits: 8
  is_full_quantize: false
  onnx_format: True
  activation_quantize_type: moving_average_abs_max
  weight_quantize_type: channel_wise_abs_max
  not_quant_pattern:
  - skip_quant
  quantize_op_types:
  - conv2d
  - depthwise_conv2d
  weight_bits: 8

TrainConfig:
  epochs: 10
  eval_iter: 2000
  logging_iter: 100
  learning_rate:
    type: CosineAnnealingDecay
    learning_rate: 0.00005
  optimizer_builder:
    optimizer:
      type: Adam
    weight_decay: 5.0e-05

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc
  ignore_space: False

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 64
    drop_last: true
    num_workers: 0

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
    - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 64
    num_workers: 0
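
The `Quantization` section of the config above sets `weight_quantize_type: channel_wise_abs_max` with `weight_bits: 8`. As a rough illustration of that idea, the sketch below gives each channel its own scale (its largest absolute weight) and maps the weights to signed 8-bit integers. It is a simplified pure-Python model with hypothetical function names and toy values, not PaddleSlim's implementation.

```python
def quantize_channel_wise_abs_max(weights, bits=8):
    """Quantize each channel (row) with its own abs-max scale.

    `weights` is a list of channels, each a list of floats.
    Returns (quantized integer rows, per-channel scales).
    """
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    quantized, scales = [], []
    for channel in weights:
        scale = max(abs(w) for w in channel)
        if scale == 0.0:
            # An all-zero channel stays all zero.
            quantized.append([0] * len(channel))
        else:
            quantized.append([round(w / scale * qmax) for w in channel])
        scales.append(scale)
    return quantized, scales


def dequantize(quantized, scales, bits=8):
    """Map the integers back to approximate float weights."""
    qmax = 2 ** (bits - 1) - 1
    return [[q * s / qmax for q in row] for row, s in zip(quantized, scales)]


# Toy weights: two channels with very different ranges.
weights = [[0.5, -1.0, 0.25], [0.02, 0.01, -0.04]]
q, s = quantize_channel_wise_abs_max(weights)
print(q)
print(s)
print(dequantize(q, s))
```

With one scale per channel, the second channel, whose values are all small, still uses the full integer range. A single scale for the whole tensor would squeeze it into a handful of levels.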
46 changes: 10 additions & 36 deletions example/auto_compression/ocr/eval.py
@@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
import numpy as np
import argparse
from tqdm import tqdm
import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.common import get_logger
from paddleslim.auto_compression import AutoCompression

import sys
sys.path.append('../PaddleOCR')
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric

@@ -37,7 +34,7 @@ def argsparser():
    parser.add_argument(
        '--config_path',
        type=str,
        default='./image_classification/configs/eval.yaml',
        default='./configs/ppocrv3_det_qat_dist.yaml',
        help="path of compression strategy config.")
    parser.add_argument(
        '--model_dir',
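
The hunk above changes the default of `--config_path` from an image-classification config to the OCR detection config. A minimal sketch of the effect, reproducing only this one argument (the `--model_dir` definition is cut off in the hunk, so it is left out here; `build_parser` and the description string are illustrative names):

```python
import argparse


def build_parser():
    """Parser with only the `--config_path` argument from the hunk above."""
    parser = argparse.ArgumentParser(description="OCR evaluation (sketch)")
    parser.add_argument(
        '--config_path',
        type=str,
        default='./configs/ppocrv3_det_qat_dist.yaml',
        help="path of compression strategy config.")
    return parser


# With no command-line arguments, the default config path is used.
args = build_parser().parse_args([])
print(args.config_path)
```

Evaluating the recognition model still requires passing its config explicitly, for example `--config_path='./configs/ppocrv3_rec_qat_dist.yaml'`.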
@@ -47,20 +44,6 @@ def argsparser():
    return parser


extra_input_models = [
    "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
]


def sample_generator(loader):
    def __reader__():
        for indx, data in enumerate(loader):
            images = np.array(data[0])
            yield images

    return __reader__


def eval():
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
@@ -70,43 +53,34 @@ def eval():
        exe,
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"])
    print('Loaded model from: {}'.format(global_config["model_dir"]))
    logger.info('Loaded model from: {}'.format(global_config["model_dir"]))

    val_loader = build_dataloader(all_config, 'Eval', devices, logger)
    post_process_class = build_post_process(all_config['PostProcess'],
                                            global_config)
    eval_class = build_metric(all_config['Metric'])
    model_type = global_config['model_type']
    extra_input = True if global_config[
        'algorithm'] in extra_input_models else False

    with tqdm(
            total=len(val_loader),
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for batch_id, batch in enumerate(val_loader):
            images = batch[0]
            if extra_input:
                preds = exe.run(
                    val_program,
                    feed={feed_target_names[0]: images,
                          'data': batch[1:]},
                    fetch_list=fetch_targets)
            else:
                preds = exe.run(val_program,
                                feed={feed_target_names[0]: images},
                                fetch_list=fetch_targets)
            preds, = exe.run(val_program,
                             feed={feed_target_names[0]: images},
                             fetch_list=fetch_targets)

            batch_numpy = []
            for item in batch:
                batch_numpy.append(np.array(item))

            if model_type == 'det':
                preds_map = {'maps': preds[0]}
                preds_map = {'maps': preds}
                post_result = post_process_class(preds_map, batch_numpy[1])
                eval_class(post_result, batch_numpy)
            elif model_type == 'rec':
                post_result = post_process_class(preds[0], batch_numpy[1])
                post_result = post_process_class(preds, batch_numpy[1])
                eval_class(post_result, batch_numpy)

            t.update()
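
In the `rec` branch above, the raw predictions go to the post-process class, which the new config sets to `CTCLabelDecode`. The general idea behind greedy CTC decoding can be sketched as follows: take the best class at each time step, merge consecutive repeats, and drop the blank symbol. This is a simplified illustration that assumes the blank is class index 0 and uses a toy two-character alphabet; it is not PaddleOCR's implementation.

```python
BLANK = 0  # assumption: class index 0 is the CTC blank


def ctc_greedy_decode(probs, charset):
    """Greedy CTC decoding for one sequence.

    probs: list of time steps, each a list of class scores.
    charset: characters for class indices 1..N (index 0 is the blank).
    """
    # Best class per time step.
    best_path = [max(range(len(step)), key=step.__getitem__) for step in probs]
    chars = []
    previous = None
    for index in best_path:
        # Keep a class only when it differs from the previous step
        # and is not the blank.
        if index != previous and index != BLANK:
            chars.append(charset[index - 1])
        previous = index
    return "".join(chars)


# Toy scores over classes [blank, 'a', 'b'] for 5 time steps.
toy_probs = [
    [0.1, 0.8, 0.1],    # a
    [0.1, 0.7, 0.2],    # a again (merged with the previous step)
    [0.9, 0.05, 0.05],  # blank
    [0.1, 0.8, 0.1],    # a (new character, separated by the blank)
    [0.2, 0.1, 0.7],    # b
]
print(ctc_greedy_decode(toy_probs, "ab"))  # aab
```

The blank is what lets the decoder emit a doubled character: without the blank step between them, the two runs of `a` would collapse into one.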