Skip to content

Commit

Permalink
Merge AnalysisPTQ & AnalysisQAT to Analysis (PaddlePaddle#1692)
Browse files Browse the repository at this point in the history
  • Loading branch information
RachelXu7 authored Mar 21, 2023
1 parent 2bb09da commit 3b2ed2c
Show file tree
Hide file tree
Showing 14 changed files with 909 additions and 1,134 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# PTQ(Post Training Quantization)量化分析工具详细教程
# 量化分析工具详细教程

## 1. 量化分析工具功能
1. 统计分析(statistical_analyse):
Expand All @@ -13,17 +13,18 @@
- 输入预期精度,直接产出符合预期精度的量化模型。


## 2. paddleslim.quant.AnalysisPTQ 可传入参数解析
## 2. paddleslim.quant.Analysis 可传入参数解析
| **参数名** | **参数释义** |
|-----------------------------|-----------------------------------------|
| model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| float_model_dir | 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可 |
| quant_model_dir | 默认为None,传入的量化模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可; 若不传入,分析工具将使用PTQ进行量化并分析|
| model_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| params_filename | 默认为None,若model_dir为文件夹名,则必须传入以'.pdiparams'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入 |
| eval_function | 若需要验证精度,需要传入自定义的验证函数;若不传入,精度误差分析将根据Cosine Similarity计算得出 |
| data_loader | 模型校准时使用的数据,DataLoader继承自`paddle.io.DataLoader`。可以直接使用模型套件中的DataLoader,或者根据[paddle.io.DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)自定义所需要的DataLoader |
| save_dir | 分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`|
| resume | 是否加载中间分析文件,默认为False|
| ptq_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |
| quant_config | 可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post) |



Expand All @@ -45,7 +46,7 @@ import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis
paddle.enable_static()

class ImageNetDataset(DatasetFolder):
Expand All @@ -72,12 +73,12 @@ image = paddle.static.data(
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=8, return_list=False)

analyzer = AnalysisPTQ(
model_dir="./MobileNetV1_infer",
analyzer = Analysis(
float_model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_analysis",
ptq_config={
quant_config={
'quantizable_op_type': ["conv2d", "depthwise_conv2d"],
'weight_quantize_type': 'abs_max',
'activation_quantize_type': 'moving_average_abs_max',
Expand Down Expand Up @@ -124,22 +125,17 @@ analyzer.statistical_analyse()
```shell
analyzer.metric_error_analyse()
```
调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。
若不传入quant_model_dir,并且调用该接口,会遍历量化模型中的一层,并计算量化该层后模型的损失。调用该接口时,需要输入Eval Function。会产出所有只量化一层的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。

若传入quant_model_dir,并且调用该接口,会遍历量化模型中的每一层,去掉量化节点并计算当前层不量化的模型精度。调用该接口时,需要输入Eval Function。会产出所有去掉一层量化的模型精度排序,将默认保存在 `./analysis_results/analysis.txt` 中。具体使用可参考[GPT量化训练敏感度分析DEMO](../../../../example/quantization_analysis/GPT/README.md)




**直接产出符合预期精度的目标量化模型**
```shell
analyzer.get_target_quant_model(target_metric=70.0)
analyzer.get_target_quant_model(target_metric=0.70)
```

## 4. 根据分析结果执行离线量化
执行完量化分析工具后,可根据 `analysis.txt` 中的精度排序,在量化中去掉效果较差的层,具体操作为:在调用 `paddleslim.quant.quant_post_static` 时加入参数 `skip_tensor_list`,将需要去掉的层传入即可。


## FAQ:
- 与QAT(Quantization-Aware Training)量化分析工具的区别:与QAT量化分析工具不同的是,PTQ量化分析工具则是加载待量化的原模型,对模型所有层依次进行量化,每次量化一层,进行验证获取精度误差分析。而QAT量化分析工具加载量化训练后的量化模型,遍历所有量化的层,依次去掉量化层,加载Float模型的参数,并进行验证获取精度误差分析。

- PTQ量化分析工具设计的原因:PTQ量化分析工具依次量化模型中的每一层,而不是依次去掉量化层是由于PTQ本身的高效性。依次量化一层进行验证,查看对模型精度的损失十分直观。

- 量化分析工具为什么要区分PTQ和QAT:实验证明PTQ和QAT后的量化模型的敏感层并不完全一致,将两种算法分开,敏感度分析结果更加准确。
98 changes: 0 additions & 98 deletions docs/zh_cn/tutorials/quant/AnalysisQAT.md

This file was deleted.

34 changes: 17 additions & 17 deletions example/post_training_quantization/detection/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from keypoint_utils import keypoint_post_process
from post_process import PPYOLOEPostProcess
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis


def argsparser():
Expand Down Expand Up @@ -87,10 +87,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
elif isinstance(config['input_list'], dict):
if k in config['input_list'].keys():
data_input[config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
if 'arch' in config and config['arch'] == 'keypoint':
res = keypoint_post_process(data, data_input, exe,
Expand All @@ -115,8 +116,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric.log()
map_res = metric.get_results()
metric.reset()
map_key = 'keypoint' if 'arch' in config and config[
'arch'] == 'keypoint' else 'bbox'
map_key = 'keypoint' if 'arch' in config and config['arch'] == 'keypoint' else 'bbox'
return map_res[map_key][0]


Expand All @@ -127,9 +127,8 @@ def main():
ptq_config = config['PTQ']

# val dataset is sufficient for PTQ
data_loader = create('EvalReader')(config['EvalDataset'],
config['worker_num'],
return_list=True)
data_loader = create('EvalReader')(
config['EvalDataset'], config['worker_num'], return_list=True)
ptq_data_loader = reader_wrapper(data_loader, config['input_list'])

# fast_val_anno_path, such as annotation path of several pictures can accelerate analysis
Expand All @@ -139,10 +138,11 @@ def main():
global val_loader
_eval_batch_sampler = paddle.io.BatchSampler(
dataset, batch_size=config['EvalReader']['batch_size'])
val_loader = create('EvalReader')(dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
val_loader = create('EvalReader')(
dataset,
config['worker_num'],
batch_sampler=_eval_batch_sampler,
return_list=True)
global metric
if config['metric'] == 'COCO':
clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
Expand All @@ -161,14 +161,14 @@ def main():
else:
raise ValueError("metric currently only supports COCO and VOC.")

analyzer = AnalysisPTQ(
model_dir=config["model_dir"],
analyzer = Analysis(
float_model_dir=config["model_dir"],
model_filename=config["model_filename"],
params_filename=config["params_filename"],
eval_function=eval_function,
data_loader=ptq_data_loader,
save_dir=config['save_dir'],
ptq_config=ptq_config,
quant_config=ptq_config,
resume=True, )

analyzer.statistical_analyse()
Expand Down
20 changes: 11 additions & 9 deletions example/post_training_quantization/pytorch_yolo_series/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from post_process import YOLOPostProcess, coco_metric
from dataset import COCOValDataset, COCOTrainDataset
from paddleslim.common import load_config, load_onnx_model
from paddleslim.quant.analysis_ptq import AnalysisPTQ
from paddleslim.quant.analysis import Analysis


def argsparser():
Expand All @@ -41,7 +41,8 @@ def argsparser():
'--resume',
type=bool,
default=False,
help="When break off while ananlyzing, could resume analysis program and load already analyzed information."
help=
"When break off while ananlyzing, could resume analysis program and load already analyzed information."
)
return parser

Expand All @@ -54,10 +55,11 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
ncols=80) as t:
for data in val_loader:
data_all = {k: np.array(v) for k, v in data.items()}
outs = exe.run(compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
outs = exe.run(
compiled_test_program,
feed={test_feed_names[0]: data_all['image']},
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
postprocess = YOLOPostProcess(
score_threshold=0.001, nms_threshold=0.65, multi_label=True)
Expand Down Expand Up @@ -103,15 +105,15 @@ def main():
load_onnx_model(config["model_dir"])
inference_model_path = config["model_dir"].rstrip().rstrip(
'.onnx') + '_infer'
analyzer = AnalysisPTQ(
model_dir=inference_model_path,
analyzer = Analysis(
float_model_dir=inference_model_path,
model_filename='model.pdmodel',
params_filename='model.pdiparams',
eval_function=eval_function,
data_loader=data_loader,
save_dir=config['save_dir'],
resume=FLAGS.resume,
ptq_config=ptq_config)
quant_config=ptq_config)

analyzer.statistical_analyse()
analyzer.metric_error_analyse()
Expand Down
8 changes: 3 additions & 5 deletions example/quantization_analysis/GPT/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import paddle
from paddleslim.common import load_config as load_slim_config
from paddleslim.quant.analysis_qat import AnalysisQAT
from paddleslim.quant.analysis import Analysis
from ppfleetx.data import build_dataloader
from ppfleetx.distributed.apis import env
from utils import parse_config
Expand Down Expand Up @@ -164,17 +164,15 @@ def main():
global eval_loader
eval_loader = eval_reader_wrapper(valid_data_loader)

analyzer = AnalysisQAT(
analyzer = Analysis(
quant_model_dir=global_config["quant_model_dir"],
float_model_dir=global_config["float_model_dir"],
model_filename=global_config["model_filename"],
params_filename=global_config["params_filename"],
quantizable_op_type=global_config['quantizable_op_type'],
qat_metric=global_config['qat_metric']
if 'qat_metric' in global_config else None,
eval_function=eval_function,
data_loader=eval_loader,
save_dir=FLAGS.save_dir,
quant_config=all_config['quant_config'],
resume=global_config['resume'], )
analyzer.metric_error_analyse()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@ Global:
float_model_dir: ./GPT_345M_Baseline
model_filename: model.pdmodel
params_filename: model.pdiparams
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
resume: False
reader_config: ./configs/gpt_reader.yaml
cloze_eval: True # True for LAMBADA Dataset; False for WikiText


quant_config:
quantizable_op_type: ["mul", "matmul", "matmul_v2"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 8
batch_nums: 10


Loading

0 comments on commit 3b2ed2c

Please sign in to comment.