diff --git a/README.md b/README.md index 283ee10e0..be5ab942c 100755 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ PaddleSlim是一个专注于深度学习模型压缩的工具库,提供**低 - 支持代码无感知压缩:用户只需提供推理模型文件和数据,既可进行离线量化(PTQ)、量化训练(QAT)、稀疏训练等压缩任务。 - 支持自动策略选择,根据任务特点和部署环境特性:自动搜索合适的离线量化方法,自动搜索最佳的压缩策略组合方式。 - 发布[自然语言处理](example/auto_compression/nlp)、[图像语义分割](example/auto_compression/semantic_segmentation)、[图像目标检测](example/auto_compression/detection)三个方向的自动化压缩示例。 - - 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[HuggingFace](example/auto_compression/pytorch_huggingface) [MobileNet](example/auto_compression/tensorflow_mobilenet)。 + - 发布`X2Paddle`模型自动化压缩方案:[YOLOv5](example/auto_compression/pytorch_yolov5)、[YOLOv6](example/auto_compression/pytorch_yolov6)、[HuggingFace](example/auto_compression/pytorch_huggingface)、[MobileNet](example/auto_compression/tensorflow_mobilenet)。 - 升级量化功能 diff --git a/example/auto_compression/pytorch_yolov6/README.md b/example/auto_compression/pytorch_yolov6/README.md new file mode 100644 index 000000000..662778ff5 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/README.md @@ -0,0 +1,138 @@ +# YOLOv6自动压缩示例 + +目录: +- [1.简介](#1简介) +- [2.Benchmark](#2Benchmark) +- [3.开始自动压缩](#自动压缩流程) + - [3.1 环境准备](#31-准备环境) + - [3.2 准备数据集](#32-准备数据集) + - [3.3 准备预测模型](#33-准备预测模型) + - [3.4 测试模型精度](#34-测试模型精度) + - [3.5 自动压缩并产出模型](#35-自动压缩并产出模型) +- [4.预测部署](#4预测部署) +- [5.FAQ](5FAQ) + +## 1. 简介 + +飞桨模型转换工具[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)支持将```Caffe/TensorFlow/ONNX/PyTorch```的模型一键转为飞桨(PaddlePaddle)的预测模型。借助X2Paddle的能力,各种框架的推理模型可以很方便的使用PaddleSlim的自动化压缩功能。 + +本示例将以[meituan/YOLOv6](https://github.com/meituan/YOLOv6)目标检测模型为例,将PyTorch框架模型转换为Paddle框架模型,再使用ACT自动压缩功能进行自动压缩。本示例使用的自动压缩策略为量化训练。 + +## 2.Benchmark + +| 模型 | 策略 | 输入尺寸 | mAPval
0.5:0.95 | 预测时延FP32
(ms) | 预测时延FP16
(ms) | 预测时延INT8
(ms) | 配置文件 | Inference模型 | +| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: | +| YOLOv6s | Base模型 | 640*640 | 42.4 | 9.06ms | 2.90ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar) | +| YOLOv6s | KL离线量化 | 640*640 | 30.3 | - | - | 1.83ms | - | - | +| YOLOv6s | 量化蒸馏训练 | 640*640 | **41.3** | - | - | **1.83ms** | [config](./configs/yolov6s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov6s_quant.tar) | + +说明: +- mAP的指标均在COCO val2017数据集中评测得到。 +- YOLOv6s模型在Tesla T4的GPU环境下开启TensorRT 8.4.1,batch_size=1, 测试脚本是[cpp_infer](./cpp_infer)。 + +## 3. 自动压缩流程 + +#### 3.1 准备环境 +- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装) +- PaddleSlim > 2.3版本 +- PaddleDet >= 2.4 +- [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) >= 1.3.6 +- opencv-python + +(1)安装paddlepaddle: +```shell +# CPU +pip install paddlepaddle +# GPU +pip install paddlepaddle-gpu +``` + +(2)安装paddleslim: +```shell +pip install paddleslim +``` + +(3)安装paddledet: +```shell +pip install paddledet +``` + +注:安装PaddleDet的目的只是为了直接使用PaddleDetection中的Dataloader组件。 + +(4)安装X2Paddle的1.3.6以上版本: +```shell +pip install x2paddle sympy onnx +``` + +#### 3.2 准备数据集 + +本案例默认以COCO数据进行自动压缩实验,并且依赖PaddleDetection中数据读取模块,如果自定义COCO数据,或者其他格式数据,请参考[PaddleDetection数据准备文档](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/docs/tutorials/PrepareDataSet.md) 来准备数据。 + +如果已经准备好数据集,请直接修改[./configs/yolov6_reader.yml]中`EvalDataset`的`dataset_dir`字段为自己数据集路径即可。 + + +#### 3.3 准备预测模型 + +(1)准备ONNX模型: + +可通过[meituan/YOLOv6](https://github.com/meituan/YOLOv6)官方的[导出教程](https://github.com/meituan/YOLOv6/blob/main/deploy/ONNX/README.md)来准备ONNX模型。也可以下载已经准备好的[yolov6s.onnx](https://paddle-slim-models.bj.bcebos.com/act/yolov6s.onnx)。 + + +(2) 转换模型: +``` +x2paddle --framework=onnx --model=yolov6s.onnx --save_dir=pd_model +cp -r pd_model/inference_model/ yolov6s_infer +``` +即可得到YOLOv6s模型的预测模型(`model.pdmodel` 和 `model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv6s的[Paddle预测模型](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov6s_infer.tar)。 + + +预测模型的格式为:`model.pdmodel` 和 `model.pdiparams`两个,带`pdmodel`的是模型文件,带`pdiparams`后缀的是权重文件。 + + +#### 3.4 自动压缩并产出模型 + +蒸馏量化自动压缩示例通过run.py脚本启动,会使用接口```paddleslim.auto_compression.AutoCompression```对模型进行自动压缩。配置config文件中模型路径、蒸馏、量化、和训练等部分的参数,配置完成后便可对模型进行量化和蒸馏。具体运行命令为: + +- 单卡训练: +``` +export CUDA_VISIBLE_DEVICES=0 +python run.py --config_path=./configs/yolov6s_qat_dis.yaml --save_dir='./output/' +``` + +- 多卡训练: +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \ + --config_path=./configs/yolov6s_qat_dis.yaml --save_dir='./output/' +``` + +#### 3.5 测试模型精度 + +修改[yolov6s_qat_dis.yaml](./configs/yolov6s_qat_dis.yaml)中`model_dir`字段为模型存储路径,然后使用eval.py脚本得到模型的mAP: +``` +export CUDA_VISIBLE_DEVICES=0 +python eval.py --config_path=./configs/yolov6s_qat_dis.yaml +``` + + +## 4.预测部署 + +#### Paddle-TensorRT C++部署 + +进入[cpp_infer](./cpp_infer)文件夹内,请按照[C++ TensorRT Benchmark测试教程](./cpp_infer/README.md)进行准备环境及编译,然后开始测试: +```shell +# 编译 +bash complie.sh +# 执行 +./build/trt_run --model_file yolov6s_quant/model.pdmodel --params_file yolov6s_quant/model.pdiparams --run_mode=trt_int8 +``` + +#### Paddle-TensorRT Python部署: + 
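+
+该部署方式的核心是在 Paddle Inference 的 `Config` 中开启 TensorRT INT8 引擎。下面给出一个最小示意(模型路径、显存大小等均为示例取值,完整流程请参考下文的 [paddle_trt_infer.py](./paddle_trt_infer.py)):
+
+```python
+# 最小示意:加载量化后的 YOLOv6s 模型并开启 TensorRT INT8(路径仅为示例)
+from paddle.inference import Config, create_predictor
+
+config = Config("yolov6s_quant/model.pdmodel", "yolov6s_quant/model.pdiparams")
+config.enable_use_gpu(256, 0)  # 初始显存(MB)与 GPU 卡号
+config.enable_tensorrt_engine(
+    workspace_size=1 << 30,
+    max_batch_size=1,
+    min_subgraph_size=3,
+    precision_mode=Config.Precision.Int8,  # 量化模型用 INT8 精度执行
+    use_static=False,
+    use_calib_mode=False)  # 量化模型自带 scale 信息,一般无需再开启 TRT 校准
+config.enable_memory_optim()
+predictor = create_predictor(config)
+```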
+首先安装带有TensorRT的[Paddle安装包](https://www.paddlepaddle.org.cn/inference/v2.3/user_guides/download_lib.html#python)。 + +然后使用[paddle_trt_infer.py](./paddle_trt_infer.py)进行部署: +```shell +python paddle_trt_infer.py --model_path=output --image_file=images/000000570688.jpg --benchmark=True --run_mode=trt_int8 +``` + +## 5.FAQ diff --git a/example/auto_compression/pytorch_yolov6/configs/yolov6_reader.yml b/example/auto_compression/pytorch_yolov6/configs/yolov6_reader.yml new file mode 100644 index 000000000..cb87c3f8f --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/configs/yolov6_reader.yml @@ -0,0 +1,27 @@ +metric: COCO +num_classes: 80 + +# Datset configuration +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco/ + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco/ + +worker_num: 0 + +# preprocess reader in test +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: True} + - Pad: {size: [640, 640], fill_value: [114., 114., 114.]} + - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True} + - Permute: {} + batch_size: 1 diff --git a/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml b/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml new file mode 100644 index 000000000..4fcf4777f --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/configs/yolov6s_qat_dis.yaml @@ -0,0 +1,31 @@ + +Global: + reader_config: configs/yolov6_reader.yml + input_list: {'image': 'x2paddle_image_arrays'} + Evaluation: True + arch: 'YOLOv6' + model_dir: ./yolov6s_infer + model_filename: model.pdmodel + params_filename: model.pdiparams + +Distillation: + alpha: 1.0 + loss: soft_label + +Quantization: + activation_quantize_type: 'moving_average_abs_max' + quantize_op_types: + - conv2d + - depthwise_conv2d + +TrainConfig: + train_iter: 8000 + eval_iter: 1000 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.00003 + T_max: 8000 + optimizer_builder: + optimizer: + type: SGD + weight_decay: 0.00004 diff --git a/example/auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt b/example/auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt new file mode 100644 index 000000000..d5307c657 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/cpp_infer/CMakeLists.txt @@ -0,0 +1,263 @@ +cmake_minimum_required(VERSION 3.0) +project(cpp_inference_demo CXX C) +option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) +option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) +option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) +option(WITH_ROCM "Compile demo with rocm." OFF) +option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) +option(WITH_ARM "Compile demo with ARM" OFF) +option(WITH_MIPS "Compile demo with MIPS" OFF) +option(WITH_SW "Compile demo with SW" OFF) +option(WITH_XPU "Compile demow ith xpu" OFF) +option(WITH_NPU "Compile demow ith npu" OFF) + +if(NOT WITH_STATIC_LIB) + add_definitions("-DPADDLE_WITH_SHARED_LIB") +else() + # PD_INFER_DECL is mainly used to set the dllimport/dllexport attribute in dynamic library mode. + # Set it to empty in static library mode to avoid compilation issues. 
+ add_definitions("/DPD_INFER_DECL=") +endif() + +macro(safe_set_static_flag) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/MD") + endforeach(flag_var) +endmacro() + +if(NOT DEFINED PADDLE_LIB) + message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib") +endif() +if(NOT DEFINED DEMO_NAME) + message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name") +endif() + +include_directories("${PADDLE_LIB}/") +set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/include") +include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/include") + +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") +link_directories("${PADDLE_LIB}/paddle/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib") +link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib") + +if (WIN32) + add_definitions("/DGOOGLE_GLOG_DLL_DECL=") + option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) + if (MSVC_STATIC_CRT) + if (WITH_MKL) + set(FLAG_OPENMP "/openmp") + endif() + set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}") + safe_set_static_flag() + if (WITH_STATIC_LIB) + add_definitions(-DSTATIC_LIB) + endif() + endif() +else() + if(WITH_MKL) + set(FLAG_OPENMP "-fopenmp") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}") +endif() + +if(WITH_GPU) + if(NOT WIN32) + include_directories("/usr/local/cuda/include") + if(CUDA_LIB STREQUAL "") + set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library") + endif() + else() + include_directories("C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\include") + if(CUDA_LIB STREQUAL "") + set(CUDA_LIB "C:\\Program\ Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v8.0\\lib\\x64") + endif() + endif(NOT WIN32) +endif() + +if (USE_TENSORRT AND WITH_GPU) + set(TENSORRT_ROOT "" CACHE STRING "The root directory of TensorRT library") + if("${TENSORRT_ROOT}" STREQUAL "") + message(FATAL_ERROR "The TENSORRT_ROOT is empty, you must assign it a value with CMake command. 
Such as: -DTENSORRT_ROOT=TENSORRT_ROOT_PATH ") + endif() + set(TENSORRT_INCLUDE_DIR ${TENSORRT_ROOT}/include) + set(TENSORRT_LIB_DIR ${TENSORRT_ROOT}/lib) + file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION + "${TENSORRT_VERSION_FILE_CONTENTS}") + if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") + file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h TENSORRT_VERSION_FILE_CONTENTS) + string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION + "${TENSORRT_VERSION_FILE_CONTENTS}") + endif() + if("${TENSORRT_MAJOR_VERSION}" STREQUAL "") + message(SEND_ERROR "Failed to detect TensorRT version.") + endif() + string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1" + TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}") + message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " + "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ") + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") +endif() + +if(WITH_MKL) + set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") + include_directories("${MATH_LIB_PATH}/include") + if(WIN32) + set(MATH_LIB ${MATH_LIB_PATH}/lib/mklml${CMAKE_STATIC_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5md${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + ${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn") + if(EXISTS ${MKLDNN_PATH}) + include_directories("${MKLDNN_PATH}/include") + if(WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/mkldnn.lib) + else(WIN32) + set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) + endif(WIN32) + endif() +elseif((NOT WITH_MIPS) AND (NOT WITH_SW)) + set(OPENBLAS_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}openblas") + include_directories("${OPENBLAS_LIB_PATH}/include/openblas") + if(WIN32) + set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(MATH_LIB ${OPENBLAS_LIB_PATH}/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() +endif() + +if(WITH_STATIC_LIB) + set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) +else() + if(WIN32) + set(DEPS ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX}) + else() + set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() +endif() + +if (WITH_ONNXRUNTIME) + if(WIN32) + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.lib paddle2onnx) + elseif(APPLE) + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.1.10.0.dylib paddle2onnx) + else() + set(DEPS ${DEPS} ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/libonnxruntime.so.1.10.0 paddle2onnx) + endif() +endif() + +if (NOT WIN32) + set(EXTERNAL_LIB "-lrt -ldl -lpthread") + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags protobuf xxhash cryptopp + ${EXTERNAL_LIB}) +else() + set(DEPS ${DEPS} + ${MATH_LIB} ${MKLDNN_LIB} + glog gflags_static libprotobuf xxhash cryptopp-static ${EXTERNAL_LIB}) + set(DEPS ${DEPS} shlwapi.lib) +endif(NOT WIN32) + +if(WITH_GPU) + if(NOT WIN32) + if (USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX}) + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) + 
else() + if(USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() + endif() + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) + endif() +endif() + +if(WITH_ROCM AND NOT WIN32) + set(DEPS ${DEPS} ${ROCM_LIB}/libamdhip64${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +if(WITH_XPU AND NOT WIN32) + set(XPU_INSTALL_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}xpu") + set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpuapi${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${XPU_INSTALL_PATH}/lib/libxpurt${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +if(WITH_NPU AND NOT WIN32) + set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libgraph${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libge_runner${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libascendcl${CMAKE_SHARED_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64/libacl_op_compiler${CMAKE_SHARED_LIBRARY_SUFFIX}) +endif() + +add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) +target_link_libraries(${DEMO_NAME} ${DEPS}) +if(WIN32) + if(USE_TENSORRT) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/nvinfer_plugin${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + ) + if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${TENSORRT_LIB_DIR}/myelin64_1${CMAKE_SHARED_LIBRARY_SUFFIX} + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) + endif() + endif() + if(WITH_MKL) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/mklml.dll ${CMAKE_BINARY_DIR}/Release + COMMAND ${CMAKE_COMMAND} -E copy ${MATH_LIB_PATH}/lib/libiomp5md.dll ${CMAKE_BINARY_DIR}/Release + COMMAND ${CMAKE_COMMAND} -E copy ${MKLDNN_PATH}/lib/mkldnn.dll ${CMAKE_BINARY_DIR}/Release + ) + else() + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${OPENBLAS_LIB_PATH}/lib/openblas.dll ${CMAKE_BINARY_DIR}/Release + ) + endif() + if(WITH_ONNXRUNTIME) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}onnxruntime/lib/onnxruntime.dll + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + COMMAND ${CMAKE_COMMAND} -E copy ${PADDLE_LIB_THIRD_PARTY_PATH}paddle2onnx/lib/paddle2onnx.dll + ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + ) + endif() + if(NOT WITH_STATIC_LIB) + add_custom_command(TARGET ${DEMO_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${PADDLE_LIB}/paddle/lib/paddle_inference.dll" ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE} + ) + endif() +endif() diff --git a/example/auto_compression/pytorch_yolov6/cpp_infer/README.md b/example/auto_compression/pytorch_yolov6/cpp_infer/README.md new file mode 100644 
index 000000000..2f2204862 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/cpp_infer/README.md @@ -0,0 +1,50 @@ +# YOLOv6 TensorRT Benchmark测试(Linux) + +## 环境准备 + +- CUDA、CUDNN:确认环境中已经安装CUDA和CUDNN,并且提前获取其安装路径。 + +- TensorRT:可通过NVIDIA官网下载[TensorRT 8.4.1.5](https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz)或其他版本安装包。 + +- Paddle Inference C++预测库:编译develop版本请参考[编译文档](https://www.paddlepaddle.org.cn/inference/user_guides/source_compile.html)。编译完成后,会在build目录下生成`paddle_inference_install_dir`文件夹,这个就是我们需要的C++预测库文件。 + +## 编译可执行程序 + +- (1)修改`compile.sh`中依赖库路径,主要是以下内容: +```shell +# Paddle Inference预测库路径 +LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/ +# CUDNN路径 +CUDNN_LIB=/usr/lib/x86_64-linux-gnu/ +# CUDA路径 +CUDA_LIB=/usr/local/cuda/lib64 +# TensorRT安装包路径,为TRT资源包解压完成后的绝对路径,其中包含`lib`和`include`文件夹 +TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/ +``` + +## 测试 + +- FP32 +``` +./build/trt_run --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp32 +``` + +- FP16 +``` +./build/trt_run --model_file yolov6s_infer/model.pdmodel --params_file yolov6s_infer/model.pdiparams --run_mode=trt_fp16 +``` + +- INT8 +``` +./build/trt_run --model_file yolov6s_quant/model.pdmodel --params_file yolov6s_quant/model.pdiparams --run_mode=trt_int8 +``` + +## 性能对比 + +| 模型 | 预测时延FP32
(ms) | 预测时延FP16
(ms) | 预测时延INT8
(ms) | +| :-------- |:-------- |:--------: | :---------------------: | +| YOLOv6s | 9.06ms | 2.90ms | 1.83ms | + +环境: +- Tesla T4,TensorRT 8.4.1,CUDA 11.2 +- batch_size=1 diff --git a/example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh b/example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh new file mode 100644 index 000000000..afff924b4 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/cpp_infer/compile.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set +x +set -e + +work_path=$(dirname $(readlink -f $0)) + +mkdir -p build +cd build +rm -rf * + +DEMO_NAME=trt_run + +WITH_MKL=ON +WITH_GPU=ON +USE_TENSORRT=ON + +LIB_DIR=/root/auto_compress/Paddle/build/paddle_inference_install_dir/ +CUDNN_LIB=/usr/lib/x86_64-linux-gnu/ +CUDA_LIB=/usr/local/cuda/lib64 +TENSORRT_ROOT=/root/auto_compress/trt/trt8.4/ + +WITH_ROCM=OFF +ROCM_LIB=/opt/rocm/lib + +cmake .. -DPADDLE_LIB=${LIB_DIR} \ + -DWITH_MKL=${WITH_MKL} \ + -DDEMO_NAME=${DEMO_NAME} \ + -DWITH_GPU=${WITH_GPU} \ + -DWITH_STATIC_LIB=OFF \ + -DUSE_TENSORRT=${USE_TENSORRT} \ + -DWITH_ROCM=${WITH_ROCM} \ + -DROCM_LIB=${ROCM_LIB} \ + -DCUDNN_LIB=${CUDNN_LIB} \ + -DCUDA_LIB=${CUDA_LIB} \ + -DTENSORRT_ROOT=${TENSORRT_ROOT} + +make -j diff --git a/example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc b/example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc new file mode 100644 index 000000000..9c14baf7d --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/cpp_infer/trt_run.cc @@ -0,0 +1,116 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "paddle/include/paddle_inference_api.h" +#include "paddle/include/experimental/phi/common/float16.h" + +using paddle_infer::Config; +using paddle_infer::Predictor; +using paddle_infer::CreatePredictor; +using paddle_infer::PrecisionType; +using phi::dtype::float16; + +DEFINE_string(model_dir, "", "Directory of the inference model."); +DEFINE_string(model_file, "", "Path of the inference model file."); +DEFINE_string(params_file, "", "Path of the inference params file."); +DEFINE_string(run_mode, "trt_fp32", "run_mode which can be: trt_fp32, trt_fp16 and trt_int8"); +DEFINE_int32(batch_size, 1, "Batch size."); +DEFINE_int32(gpu_id, 0, "GPU card ID num."); +DEFINE_int32(trt_min_subgraph_size, 3, "tensorrt min_subgraph_size"); +DEFINE_int32(warmup, 50, "warmup"); +DEFINE_int32(repeats, 1000, "repeats"); + +using Time = decltype(std::chrono::high_resolution_clock::now()); +Time time() { return std::chrono::high_resolution_clock::now(); }; +double time_diff(Time t1, Time t2) { + typedef std::chrono::microseconds ms; + auto diff = t2 - t1; + ms counter = std::chrono::duration_cast(diff); + return counter.count() / 1000.0; +} + +std::shared_ptr InitPredictor() { + Config config; + std::string model_path; + if (FLAGS_model_dir != "") { + config.SetModel(FLAGS_model_dir); + model_path = FLAGS_model_dir.substr(0, FLAGS_model_dir.find_last_of("/")); + } else { + config.SetModel(FLAGS_model_file, FLAGS_params_file); + model_path = FLAGS_model_file.substr(0, FLAGS_model_file.find_last_of("/")); + } + // enable tune + std::cout << "model_path: " << model_path << std::endl; + config.EnableUseGpu(256, FLAGS_gpu_id); + if (FLAGS_run_mode == "trt_fp32") { + config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, + PrecisionType::kFloat32, false, false); + } else if (FLAGS_run_mode == "trt_fp16") { + config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, + PrecisionType::kHalf, false, false); + } else if 
(FLAGS_run_mode == "trt_int8") { + config.EnableTensorRtEngine(1 << 30, FLAGS_batch_size, FLAGS_trt_min_subgraph_size, + PrecisionType::kInt8, false, false); + } + config.EnableMemoryOptim(); + config.SwitchIrOptim(true); + return CreatePredictor(config); +} + +template +void run(Predictor *predictor, const std::vector &input, + const std::vector &input_shape, type* out_data, std::vector out_shape) { + + // prepare input + int input_num = std::accumulate(input_shape.begin(), input_shape.end(), 1, + std::multiplies()); + + auto input_names = predictor->GetInputNames(); + auto input_t = predictor->GetInputHandle(input_names[0]); + input_t->Reshape(input_shape); + input_t->CopyFromCpu(input.data()); + + for (int i = 0; i < FLAGS_warmup; ++i) + CHECK(predictor->Run()); + + auto st = time(); + for (int i = 0; i < FLAGS_repeats; ++i) { + auto input_names = predictor->GetInputNames(); + auto input_t = predictor->GetInputHandle(input_names[0]); + input_t->Reshape(input_shape); + input_t->CopyFromCpu(input.data()); + + CHECK(predictor->Run()); + + auto output_names = predictor->GetOutputNames(); + auto output_t = predictor->GetOutputHandle(output_names[0]); + std::vector output_shape = output_t->shape(); + output_t -> ShareExternalData(out_data, out_shape, paddle_infer::PlaceType::kGPU); + } + + LOG(INFO) << "[" << FLAGS_run_mode << " bs-" << FLAGS_batch_size << " ] run avg time is " << time_diff(st, time()) / FLAGS_repeats + << " ms"; +} + +int main(int argc, char *argv[]) { + google::ParseCommandLineFlags(&argc, &argv, true); + auto predictor = InitPredictor(); + std::vector input_shape = {FLAGS_batch_size, 3, 640, 640}; + // float16 + using dtype = float16; + std::vector input_data(FLAGS_batch_size * 3 * 640 * 640, dtype(1.0)); + + dtype *out_data; + int out_data_size = FLAGS_batch_size * 8400 * 85; + cudaHostAlloc((void**)&out_data, sizeof(float) * out_data_size, cudaHostAllocMapped); + + std::vector out_shape{ FLAGS_batch_size, 1, 8400, 85}; + run(predictor.get(), input_data, input_shape, out_data, out_shape); + return 0; +} diff --git a/example/auto_compression/pytorch_yolov6/eval.py b/example/auto_compression/pytorch_yolov6/eval.py new file mode 100644 index 000000000..62127b512 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/eval.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
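+# eval.py: load the exported (optionally compressed) inference model from the
+# config's Global section, run it over EvalDataset and report COCO/VOC mAP.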
+ +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config + +from post_process import YOLOv6PostProcess + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = np.array(tmp_list) + else: + data_all = {k: np.array(v) for k, v in data.items()} + return data_all + + +def eval(): + + place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() + exe = paddle.static.Executor(place) + + val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( + global_config["model_dir"], + exe, + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"]) + print('Loaded model from: {}'.format(global_config["model_dir"])) + + metric = global_config['metric'] + for batch_id, data in enumerate(val_loader): + data_all = convert_numpy_data(data, metric) + data_input = {} + for k, v in data.items(): + if isinstance(global_config['input_list'], list): + if k in global_config['input_list']: + data_input[k] = np.array(v) + elif isinstance(global_config['input_list'], dict): + if k in global_config['input_list'].keys(): + data_input[global_config['input_list'][k]] = np.array(v) + outs = exe.run(val_program, + feed=data_input, + fetch_list=fetch_targets, + return_numpy=False) + res = {} + if 'arch' in global_config and global_config['arch'] == 'YOLOv6': + postprocess = YOLOv6PostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + else: + for out in outs: + v = np.array(out) + if len(v.shape) > 1: + res['bbox'] = v + else: + res['bbox_num'] = v + metric.update(data_all, res) + if batch_id % 100 == 0: + print('Eval iter:', batch_id) + metric.accumulate() + metric.log() + metric.reset() + + +def main(): + global global_config + all_config = load_slim_config(FLAGS.config_path) + global_config = all_config["Global"] + reader_cfg = load_config(global_config['reader_config']) + + dataset = reader_cfg['EvalDataset'] + global val_loader + val_loader = create('EvalReader')(reader_cfg['EvalDataset'], + reader_cfg['worker_num'], + return_list=True) + metric = None + if reader_cfg['metric'] == 'COCO': + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif 
reader_cfg['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=reader_cfg['num_classes'], + map_type=reader_cfg['map_type']) + else: + raise ValueError("metric currently only supports COCO and VOC.") + global_config['metric'] = metric + + eval() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/example/auto_compression/pytorch_yolov6/images/000000570688.jpg b/example/auto_compression/pytorch_yolov6/images/000000570688.jpg new file mode 100644 index 000000000..cb304bd56 Binary files /dev/null and b/example/auto_compression/pytorch_yolov6/images/000000570688.jpg differ diff --git a/example/auto_compression/pytorch_yolov6/paddle_trt_infer.py b/example/auto_compression/pytorch_yolov6/paddle_trt_infer.py new file mode 100644 index 000000000..5d88643f4 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/paddle_trt_infer.py @@ -0,0 +1,322 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import numpy as np +import argparse +import time + +from paddle.inference import Config +from paddle.inference import create_predictor + +from post_process import YOLOv6PostProcess + +CLASS_LABEL = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', + 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', + 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', + 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', + 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', + 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', + 'hair drier', 'toothbrush' +] + + +def generate_scale(im, target_shape, keep_ratio=True): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + if keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(target_shape) + target_size_max = np.max(target_shape) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = target_shape + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / 
float(origin_shape[1]) + return im_scale_y, im_scale_x + + +def image_preprocess(img_path, target_shape): + img = cv2.imread(img_path) + # Resize + im_scale_y, im_scale_x = generate_scale(img, target_shape) + img = cv2.resize( + img, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=cv2.INTER_LINEAR) + # Pad + im_h, im_w = img.shape[:2] + h, w = target_shape[:] + if h != im_h or w != im_w: + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32) + canvas[0:im_h, 0:im_w, :] = img.astype(np.float32) + img = canvas + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.transpose(img, [2, 0, 1]) / 255 + img = np.expand_dims(img, 0) + scale_factor = np.array([[im_scale_y, im_scale_x]]) + return img.astype(np.float32), scale_factor + + +def get_color_map_list(num_classes): + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +def draw_box(image_file, results, class_label, threshold=0.5): + srcimg = cv2.imread(image_file, 1) + for i in range(len(results)): + color_list = get_color_map_list(len(class_label)) + clsid2color = {} + classid, conf = int(results[i, 0]), results[i, 1] + if conf < threshold: + continue + xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int( + results[i, 4]), int(results[i, 5]) + + if classid not in clsid2color: + clsid2color[classid] = color_list[classid] + color = tuple(clsid2color[classid]) + + cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2) + print(class_label[classid] + ': ' + str(round(conf, 3))) + cv2.putText( + srcimg, + class_label[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, (0, 255, 0), + thickness=2) + return srcimg + + +def load_predictor(model_dir, + run_mode='paddle', + batch_size=1, + device='CPU', + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + delete_shuffle_pass=False): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8) + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT. + Used by action model. + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need device == 'GPU'. 
+ """ + if device != 'GPU' and run_mode != 'paddle': + raise ValueError( + "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}" + .format(run_mode, device)) + config = Config( + os.path.join(model_dir, 'model.pdmodel'), + os.path.join(model_dir, 'model.pdiparams')) + if device == 'GPU': + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + elif device == 'XPU': + config.enable_lite_engine() + config.enable_xpu(10 * 1024 * 1024) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + if enable_mkldnn: + try: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if enable_mkldnn_bfloat16: + config.enable_mkldnn_bfloat16() + except Exception as e: + print( + "The current environment does not support `mkldnn`, so disable mkldnn." + ) + pass + + precision_map = { + 'trt_int8': Config.Precision.Int8, + 'trt_fp32': Config.Precision.Float32, + 'trt_fp16': Config.Precision.Half + } + if run_mode in precision_map.keys(): + config.enable_tensorrt_engine( + workspace_size=(1 << 25) * batch_size, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[run_mode], + use_static=False, + use_calib_mode=trt_calib_mode) + + if use_dynamic_shape: + min_input_shape = { + 'image': [batch_size, 3, trt_min_shape, trt_min_shape] + } + max_input_shape = { + 'image': [batch_size, 3, trt_max_shape, trt_max_shape] + } + opt_input_shape = { + 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape] + } + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + print('trt set dynamic shape done!') + + # disable print log when predict + config.disable_glog_info() + # enable shared memory + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + if delete_shuffle_pass: + config.delete_pass("shuffle_channel_detect_pass") + predictor = create_predictor(config) + return predictor + + +def predict_image(predictor, + image_file, + image_shape=[640, 640], + warmup=1, + repeats=1, + threshold=0.5, + arch='YOLOv5'): + img, scale_factor = image_preprocess(image_file, image_shape) + inputs = {} + if arch == 'YOLOv5': + inputs['x2paddle_images'] = img + input_names = predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + for i in range(warmup): + predictor.run() + + np_boxes = None + predict_time = 0. 
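+    # Timed benchmark loop: each repeat copies the input, runs inference and fetches
+    # the first output; min/max/avg latency (ms) over `repeats` runs is printed below.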
+ time_min = float("inf") + time_max = float('-inf') + for i in range(repeats): + start_time = time.time() + predictor.run() + output_names = predictor.get_output_names() + boxes_tensor = predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + end_time = time.time() + timed = end_time - start_time + time_min = min(time_min, timed) + time_max = max(time_max, timed) + predict_time += timed + + time_avg = predict_time / repeats + print('Inference time(ms): min={}, max={}, avg={}'.format( + round(time_min * 1000, 2), + round(time_max * 1000, 1), round(time_avg * 1000, 1))) + postprocess = YOLOv6PostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np_boxes, scale_factor) + res_img = draw_box( + image_file, res['bbox'], CLASS_LABEL, threshold=threshold) + cv2.imwrite('result.jpg', res_img) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument( + '--image_file', type=str, default=None, help="image path") + parser.add_argument( + '--model_path', type=str, help="inference model filepath") + parser.add_argument( + '--benchmark', + type=bool, + default=False, + help="Whether run benchmark or not.") + parser.add_argument( + '--run_mode', + type=str, + default='paddle', + help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)") + parser.add_argument( + '--device', + type=str, + default='GPU', + help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is GPU" + ) + parser.add_argument('--img_shape', type=int, default=640, help="input_size") + args = parser.parse_args() + + predictor = load_predictor( + args.model_path, run_mode=args.run_mode, device=args.device) + warmup, repeats = 1, 1 + if args.benchmark: + warmup, repeats = 50, 100 + predict_image( + predictor, + args.image_file, + image_shape=[args.img_shape, args.img_shape], + warmup=warmup, + repeats=repeats) diff --git a/example/auto_compression/pytorch_yolov6/post_process.py b/example/auto_compression/pytorch_yolov6/post_process.py new file mode 100644 index 000000000..37bd2c959 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/post_process.py @@ -0,0 +1,173 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 + + +def box_area(boxes): + """ + Args: + boxes(np.ndarray): [N, 4] + return: [N] + """ + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +def box_iou(box1, box2): + """ + Args: + box1(np.ndarray): [N, 4] + box2(np.ndarray): [M, 4] + return: [N, M] + """ + area1 = box_area(box1) + area2 = box_area(box2) + lt = np.maximum(box1[:, np.newaxis, :2], box2[:, :2]) + rb = np.minimum(box1[:, np.newaxis, 2:], box2[:, 2:]) + wh = rb - lt + wh = np.maximum(0, wh) + inter = wh[:, :, 0] * wh[:, :, 1] + iou = inter / (area1[:, np.newaxis] + area2 - inter) + return iou + + +def nms(boxes, scores, iou_threshold): + """ + Non Max Suppression numpy implementation. 
+ args: + boxes(np.ndarray): [N, 4] + scores(np.ndarray): [N, 1] + iou_threshold(float): Threshold of IoU. + """ + idxs = scores.argsort() + keep = [] + while idxs.size > 0: + max_score_index = idxs[-1] + max_score_box = boxes[max_score_index][None, :] + keep.append(max_score_index) + if idxs.size == 1: + break + idxs = idxs[:-1] + other_boxes = boxes[idxs] + ious = box_iou(max_score_box, other_boxes) + idxs = idxs[ious[0] <= iou_threshold] + + keep = np.array(keep) + return keep + + +class YOLOv6PostProcess(object): + """ + Post process of YOLOv6 network. + args: + score_threshold(float): Threshold to filter out bounding boxes with low + confidence score. If not provided, consider all boxes. + nms_threshold(float): The threshold to be used in NMS. + multi_label(bool): Whether keep multi label in boxes. + keep_top_k(int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + """ + + def __init__(self, + score_threshold=0.25, + nms_threshold=0.5, + multi_label=False, + keep_top_k=300): + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.multi_label = multi_label + self.keep_top_k = keep_top_k + + def _xywh2xyxy(self, x): + # Convert from [x, y, w, h] to [x1, y1, x2, y2] + y = np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + def _non_max_suppression(self, prediction): + max_wh = 4096 # (pixels) minimum and maximum box width and height + nms_top_k = 30000 + + cand_boxes = prediction[..., 4] > self.score_threshold # candidates + output = [np.zeros((0, 6))] * prediction.shape[0] + + for batch_id, boxes in enumerate(prediction): + # Apply constraints + boxes = boxes[cand_boxes[batch_id]] + if not boxes.shape[0]: + continue + # Compute conf (conf = obj_conf * cls_conf) + boxes[:, 5:] *= boxes[:, 4:5] + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + convert_box = self._xywh2xyxy(boxes[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if self.multi_label: + i, j = (boxes[:, 5:] > self.score_threshold).nonzero() + boxes = np.concatenate( + (convert_box[i], boxes[i, j + 5, None], + j[:, None].astype(np.float32)), + axis=1) + else: + conf = np.max(boxes[:, 5:], axis=1) + j = np.argmax(boxes[:, 5:], axis=1) + re = np.array(conf.reshape(-1) > self.score_threshold) + conf = conf.reshape(-1, 1) + j = j.reshape(-1, 1) + boxes = np.concatenate((convert_box, conf, j), axis=1)[re] + + num_box = boxes.shape[0] + if not num_box: + continue + elif num_box > nms_top_k: + boxes = boxes[boxes[:, 4].argsort()[::-1][:nms_top_k]] + + # Batched NMS + c = boxes[:, 5:6] * max_wh + clean_boxes, scores = boxes[:, :4] + c, boxes[:, 4] + keep = nms(clean_boxes, scores, self.nms_threshold) + # limit detection box num + if keep.shape[0] > self.keep_top_k: + keep = keep[:self.keep_top_k] + output[batch_id] = boxes[keep] + return output + + def __call__(self, outs, scale_factor): + preds = self._non_max_suppression(outs) + bboxs, box_nums = [], [] + for i, pred in enumerate(preds): + if len(pred.shape) > 2: + pred = np.squeeze(pred) + if len(pred.shape) == 1: + pred = pred[np.newaxis, :] + pred_bboxes = pred[:, :4] + scale_factor = np.tile(scale_factor[i][::-1], (1, 2)) + pred_bboxes /= scale_factor + bbox = np.concatenate( + [ + pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis], + pred_bboxes + ], + axis=-1) + bboxs.append(bbox) + box_num = 
bbox.shape[0] + box_nums.append(box_num) + bboxs = np.concatenate(bboxs, axis=0) + box_nums = np.array(box_nums) + return {'bbox': bboxs, 'bbox_num': box_nums} diff --git a/example/auto_compression/pytorch_yolov6/post_quant.py b/example/auto_compression/pytorch_yolov6/post_quant.py new file mode 100644 index 000000000..7fa929dc3 --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/post_quant.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.quant import quant_post_static + +from post_process import YOLOv6PostProcess + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--save_dir', + type=str, + default='ptq_out', + help="directory to save compressed model.") + parser.add_argument( + '--algo', type=str, default='KL', help="post quant algo.") + + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def main(): + global global_config + all_config = load_slim_config(FLAGS.config_path) + assert "Global" in all_config, f"Key 'Global' not found in config file. 
\n{all_config}" + global_config = all_config["Global"] + reader_cfg = load_config(global_config['reader_config']) + + train_loader = create('EvalReader')(reader_cfg['TrainDataset'], + reader_cfg['worker_num'], + return_list=True) + train_loader = reader_wrapper(train_loader, global_config['input_list']) + + place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() + exe = paddle.static.Executor(place) + quant_post_static( + executor=exe, + model_dir=global_config["model_dir"], + quantize_model_path=FLAGS.save_dir, + data_loader=train_loader, + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"], + batch_size=32, + batch_nums=10, + algo=FLAGS.algo, + hist_percent=0.999, + is_full_quantize=False, + bias_correction=False, + onnx_format=False) + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/example/auto_compression/pytorch_yolov6/run.py b/example/auto_compression/pytorch_yolov6/run.py new file mode 100644 index 000000000..05fe7fdda --- /dev/null +++ b/example/auto_compression/pytorch_yolov6/run.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
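+# run.py: entry point for ACT auto-compression; builds the train/eval dataloaders from
+# reader_config and runs quant-aware training with distillation as set in the YAML config.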
+ +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric +from paddleslim.auto_compression.config_helpers import load_config as load_slim_config +from paddleslim.auto_compression import AutoCompression + +from post_process import YOLOv6PostProcess + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--save_dir', + type=str, + default='output', + help="directory to save compressed model.") + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + parser.add_argument( + '--eval', type=bool, default=False, help="whether to run evaluation.") + + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = np.array(tmp_list) + else: + data_all = {k: np.array(v) for k, v in data.items()} + return data_all + + +def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): + metric = global_config['metric'] + for batch_id, data in enumerate(val_loader): + data_all = convert_numpy_data(data, metric) + data_input = {} + for k, v in data.items(): + if isinstance(global_config['input_list'], list): + if k in test_feed_names: + data_input[k] = np.array(v) + elif isinstance(global_config['input_list'], dict): + if k in global_config['input_list'].keys(): + data_input[global_config['input_list'][k]] = np.array(v) + outs = exe.run(compiled_test_program, + feed=data_input, + fetch_list=test_fetch_list, + return_numpy=False) + res = {} + if 'arch' in global_config and global_config['arch'] == 'YOLOv6': + postprocess = YOLOv6PostProcess( + score_threshold=0.001, nms_threshold=0.65, multi_label=True) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + else: + for out in outs: + v = np.array(out) + if len(v.shape) > 1: + res['bbox'] = v + else: + res['bbox_num'] = v + + metric.update(data_all, res) + if batch_id % 100 == 0: + print('Eval iter:', batch_id) + metric.accumulate() + metric.log() + map_res = metric.get_results() + metric.reset() + return map_res['bbox'][0] + + +def main(): + global global_config + all_config = load_slim_config(FLAGS.config_path) + assert "Global" in all_config, f"Key 'Global' not found in config file. 
\n{all_config}" + global_config = all_config["Global"] + reader_cfg = load_config(global_config['reader_config']) + + train_loader = create('EvalReader')(reader_cfg['TrainDataset'], + reader_cfg['worker_num'], + return_list=True) + train_loader = reader_wrapper(train_loader, global_config['input_list']) + + if 'Evaluation' in global_config.keys() and global_config[ + 'Evaluation'] and paddle.distributed.get_rank() == 0: + eval_func = eval_function + dataset = reader_cfg['EvalDataset'] + global val_loader + _eval_batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=reader_cfg['EvalReader']['batch_size']) + val_loader = create('EvalReader')(dataset, + reader_cfg['worker_num'], + batch_sampler=_eval_batch_sampler, + return_list=True) + metric = None + if reader_cfg['metric'] == 'COCO': + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif reader_cfg['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=reader_cfg['num_classes'], + map_type=reader_cfg['map_type']) + else: + raise ValueError("metric currently only supports COCO and VOC.") + global_config['metric'] = metric + else: + eval_func = None + + ac = AutoCompression( + model_dir=global_config["model_dir"], + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"], + save_dir=FLAGS.save_dir, + config=all_config, + train_dataloader=train_loader, + eval_callback=eval_func) + ac.compress() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main()
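
补充示意:run.py 与 eval.py 中的 `reader_wrapper` 依据配置里的 `Global: input_list`(本例为 `{'image': 'x2paddle_image_arrays'}`)把 PaddleDetection reader 输出的字段重命名为 X2Paddle 导出模型的输入名。下面是一个可独立运行的最小演示(数据为构造值,仅用于说明该映射,并非本示例的必需步骤):

```python
# 演示 reader_wrapper 的字段名 -> 模型输入名映射(数据为构造值)
import numpy as np

def reader_wrapper(reader, input_list):
    def gen():
        for data in reader:
            # input_list 形如 {'reader字段名': '模型输入名'}
            yield {model_in: data[field] for field, model_in in input_list.items()}
    return gen

fake_reader = [{'image': np.zeros([1, 3, 640, 640], dtype=np.float32)}]
for feed in reader_wrapper(fake_reader, {'image': 'x2paddle_image_arrays'})():
    print(list(feed.keys()))  # ['x2paddle_image_arrays']
```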