Skip to content

Commit

Permalink
[RKNPU2] Update quantitative model (PaddlePaddle#879)
Browse files Browse the repository at this point in the history
* 对RKNPU2后端进行修改,当模型为非量化模型时,不在NPU执行normalize操作,当模型为量化模型时,在NUP上执行normalize操作

* 更新RKNPU2框架,输出数据的数据类型统一返回fp32类型

* 更新scrfd,拆分disable_normalize和disable_permute

* 更新scrfd代码,支持量化

* 更新scrfd python example代码

* 更新模型转换代码,支持量化模型

* 更新文档

* 按照要求修改

* 按照要求修改

* 修正模型转换文档

* 更新一下转换脚本
  • Loading branch information
Zheng-Bicheng authored Dec 19, 2022
1 parent 383887d commit dc13eb7
Show file tree
Hide file tree
Showing 18 changed files with 92 additions and 148 deletions.
67 changes: 10 additions & 57 deletions examples/vision/facedet/scrfd/rknpu2/README.md
Original file line number Diff line number Diff line change
@@ -1,67 +1,20 @@
# SCRFD RKNPU2部署模型


- [SCRFD](https://github.com/deepinsight/insightface/tree/17cdeab12a35efcebc2660453a8cbeae96e20950)
- (1)[官方库](https://github.com/deepinsight/insightface/)中提供的*.pt通过[导出ONNX模型](#导出ONNX模型)操作后,可进行部署;
- (2)开发者基于自己数据训练的SCRFD模型,可按照[导出ONNX模型](#导出ONNX模型)后,完成部署。

## 下载预训练ONNX模型

为了方便开发者的测试,下面提供了SCRFD导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库)
| 模型 | 大小 | 精度 |
|:---------------------------------------------------------------- |:----- |:----- |
| [SCRFD-500M-kps-160](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_bnkps_shape160x160.onnx) | 2.5MB | - |
| [SCRFD-500M-160](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_shape160x160.onnx) | 2.2MB | - |
| [SCRFD-500M-kps-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_bnkps_shape320x320.onnx) | 2.5MB | - |
| [SCRFD-500M-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_shape320x320.onnx) | 2.2MB | - |
| [SCRFD-500M-kps-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_bnkps_shape640x640.onnx) | 2.5MB | 90.97% |
| [SCRFD-500M-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_shape640x640.onnx) | 2.2MB | 90.57% |
| [SCRFD-1G-160](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_1g_shape160x160.onnx ) | 2.5MB | - |
| [SCRFD-1G-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_1g_shape320x320.onnx) | 2.5MB | - |
| [SCRFD-1G-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_1g_shape640x640.onnx) | 2.5MB | 92.38% |
| [SCRFD-2.5G-kps-160](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_bnkps_shape160x160.onnx) | 3.2MB | - |
| [SCRFD-2.5G-160](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_shape160x160.onnx) | 2.6MB | - |
| [SCRFD-2.5G-kps-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_bnkps_shape320x320.onnx) | 3.2MB | - |
| [SCRFD-2.5G-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_shape320x320.onnx) | 2.6MB | - |
| [SCRFD-2.5G-kps-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_bnkps_shape640x640.onnx) | 3.2MB | 93.8% |
| [SCRFD-2.5G-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_2.5g_shape640x640.onnx) | 2.6MB | 93.78% |
| [SCRFD-10G-kps-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_bnkps_shape320x320.onnx) | 17MB | - |
| [SCRFD-10G-320](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_shape320x320.onnx) | 15MB | - |
| [SCRFD-10G-kps-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_bnkps_shape640x640.onnx) | 17MB | 95.4% |
| [SCRFD-10G-640](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_shape640x640.onnx) | 15MB | 95.16% |
| [SCRFD-10G-kps-1280](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_bnkps_shape1280x1280.onnx) | 17MB | - |
| [SCRFD-10G-1280](https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_10g_shape1280x1280.onnx) | 15MB | - |

## 导出ONNX模型

```bash
#下载scrfd模型文件
e.g. download from https://onedrive.live.com/?authkey=%21ABbFJx2JMhNjhNA&id=4A83B6B633B029CC%215542&cid=4A83B6B633B029CC

# 安装官方库配置环境,此版本导出环境为:
- 手动配置环境
torch==1.8.0
mmcv==1.3.5
mmdet==2.7.0

- 通过docker配置
docker pull qyjdefdocker/onnx-scrfd-converter:v0.3

# 导出onnx格式文件
- 手动生成
python tools/scrfd2onnx.py configs/scrfd/scrfd_500m.py weights/scrfd_500m.pth --shape 640 --input-img face-xxx.jpg

- docker
docker的onnx目录中已有生成好的onnx文件

```
本教程提供SCRFD模型在RKNPU2环境下的部署,模型的详细介绍已经ONNX模型的下载请查看[模型介绍文档](../README.md)

## ONNX模型转换RKNN模型

下面以scrfd_500m_bnkps_shape640x640为例子,快速的转换SCRFD ONNX模型为RKNN量化模型。 以下命令在Ubuntu18.04下执行:
```bash
wget https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_bnkps_shape640x640.onnx
python tools/rknpu2/export.py --config_path tools/rknpu2/config/RK3588/scrfd.yaml
wget https://bj.bcebos.com/paddlehub/fastdeploy/rknpu2/scrfd_500m_bnkps_shape640x640.zip
unzip scrfd_500m_bnkps_shape640x640.zip
python /Path/To/FastDeploy/tools/rknpu2/export.py \
--config_path tools/rknpu2/config/scrfd.yaml \
--target_platform rk3588
```



## 详细部署文档

- [Python部署](python/README.md)
Expand Down
10 changes: 4 additions & 6 deletions examples/vision/facedet/scrfd/rknpu2/cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
├── CMakeLists.txt
├── build # 编译文件夹
├── image # 存放图片的文件夹
├── infer_cpu_npu.cc
├── infer_cpu_npu.h
├── main.cc
├── infer.cc
├── model # 存放模型文件的文件夹
└── thirdpartys # 存放sdk的文件夹
```
Expand All @@ -39,9 +37,8 @@ mkdir thirdpartys
请参考[RK2代NPU部署库编译](../../../../../../docs/cn/build_and_install/rknpu2.md)仓库编译SDK,编译完成后,将在build目录下生成
fastdeploy-0.7.0目录,请移动它至thirdpartys目录下.

### 拷贝模型文件,以及配置文件至model文件夹
在Paddle动态图模型 -> Paddle静态图模型 -> ONNX模型的过程中,将生成ONNX文件以及对应的yaml配置文件,请将配置文件存放到model文件夹内。
转换为RKNN后的模型文件也需要拷贝至model。
### 拷贝模型文件至model文件夹
请参考[SCRFD模型转换文档](../README.md)转换SCRFD ONNX模型到RKNN模型,再将RKNN模型移动到model文件夹。

### 准备测试图片至image文件夹
```bash
Expand All @@ -61,6 +58,7 @@ make install

```bash
cd ./build/install
export LD_LIBRARY_PATH=${PWD}/lib:${LD_LIBRARY_PATH}
./rknpu_test
```
运行完成可视化结果如下图所示
Expand Down
5 changes: 4 additions & 1 deletion examples/vision/facedet/scrfd/rknpu2/cpp/infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void ONNXInfer(const std::string& model_dir, const std::string& image_file) {
tc.End();
tc.PrintInfo("SCRFD in ONNX");

std::cout << res.Str() << std::endl;
cv::imwrite("infer_onnx.jpg", vis_im);
std::cout
<< "Visualized result saved in ./infer_onnx.jpg"
Expand All @@ -48,7 +49,8 @@ void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
model.DisableNormalizeAndPermute();
model.DisableNormalize();
model.DisablePermute();

fastdeploy::TimeCounter tc;
tc.Start();
Expand All @@ -62,6 +64,7 @@ void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) {
tc.End();
tc.PrintInfo("SCRFD in RKNN");

std::cout << res.Str() << std::endl;
cv::imwrite("infer_rknn.jpg", vis_im);
std::cout
<< "Visualized result saved in ./infer_rknn.jpg"
Expand Down
16 changes: 14 additions & 2 deletions examples/vision/facedet/scrfd/rknpu2/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/rknpu2.md)


本目录下提供`infer.py`快速完成SCRFD在RKNPU上部署的示例。执行如下脚本即可完成

## 拷贝模型文件
请参考[SCRFD模型转换文档](../README.md)转换SCRFD ONNX模型到RKNN模型,再将RKNN模型移动到该目录下。


## 运行example
拷贝模型文件后,请输入以下命令,运行RKNPU2 Python example
```bash
# 下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
Expand All @@ -20,10 +25,17 @@ python3 infer.py --model_file ./scrfd_500m_bnkps_shape640x640_rk3588.rknn \
--image test_lite_face_detector_3.jpg
```

## 可视化
运行完成可视化结果如下图所示

<img width="640" src="https://user-images.githubusercontent.com/67993288/184301789-1981d065-208f-4a6b-857c-9a0f9a63e0b1.jpg">



## 注意事项
RKNPU上对模型的输入要求是使用NHWC格式,且图片归一化操作会在转RKNN模型时,内嵌到模型中,因此我们在使用FastDeploy部署时,
需要先调用DisableNormalizePermute(C++)或`disable_normalize_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。
需要先调用DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用归一化以及数据格式的转换。

## 其它文档

- [SCRFD 模型介绍](../README.md)
Expand Down
3 changes: 2 additions & 1 deletion examples/vision/facedet/scrfd/rknpu2/python/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def build_option(args):
runtime_option=runtime_option,
model_format=fd.ModelFormat.RKNN)

model.disable_normalize_and_permute()
model.disable_normalize()
model.disable_permute()

# 预测图片分割结果
im = cv2.imread(args.image)
Expand Down
17 changes: 5 additions & 12 deletions fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/backends/rknpu/rknpu2/rknpu2_backend.h"
#include "fastdeploy/utils/perf.h"

namespace fastdeploy {
RKNPU2Backend::~RKNPU2Backend() {
// Release memory uniformly here
Expand Down Expand Up @@ -254,12 +254,11 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
void RKNPU2Backend::DumpTensorAttr(rknn_tensor_attr& attr) {
printf("index=%d, name=%s, n_dims=%d, dims=[%d, %d, %d, %d], "
"n_elems=%d, size=%d, fmt=%s, type=%s, "
"qnt_type=%s, zp=%d, scale=%f, pass_through=%d",
"qnt_type=%s, zp=%d, scale=%f\n",
attr.index, attr.name, attr.n_dims, attr.dims[0], attr.dims[1],
attr.dims[2], attr.dims[3], attr.n_elems, attr.size,
get_format_string(attr.fmt), get_type_string(attr.type),
get_qnt_type_string(attr.qnt_type), attr.zp, attr.scale,
attr.pass_through);
get_qnt_type_string(attr.qnt_type), attr.zp, attr.scale);
}

TensorInfo RKNPU2Backend::GetInputInfo(int index) {
Expand Down Expand Up @@ -310,12 +309,7 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
input_attrs_[i].type = input_type;
input_attrs_[i].size = inputs[0].Nbytes();
input_attrs_[i].size_with_stride = inputs[0].Nbytes();
if(input_attrs_[i].type == RKNN_TENSOR_FLOAT16 ||
input_attrs_[i].type == RKNN_TENSOR_FLOAT32){
FDINFO << "The input model is not a quantitative model. "
"Close the normalize operation." << std::endl;
}

input_attrs_[i].pass_through = 0;
input_mems_[i] = rknn_create_mem(ctx, inputs[i].Nbytes());
if (input_mems_[i] == nullptr) {
FDERROR << "rknn_create_mem input_mems_ error." << std::endl;
Expand Down Expand Up @@ -346,7 +340,6 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,

// default output type is depend on model, this requires float32 to compute top5
ret = rknn_set_io_mem(ctx, output_mems_[i], &output_attrs_[i]);

// set output memory and attribute
if (ret != RKNN_SUCC) {
FDERROR << "output tensor memory rknn_set_io_mem fail! ret=" << ret
Expand All @@ -357,7 +350,7 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,

this->infer_init = true;
}

// Copy input data to input tensor memory
for (uint32_t i = 0; i < io_num.n_input; i++) {
uint32_t width = input_attrs_[i].dims[2];
Expand Down
15 changes: 11 additions & 4 deletions fastdeploy/vision/facedet/contrib/scrfd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ bool SCRFD::Preprocess(Mat* mat, FDTensor* output,
is_scale_up, stride);

BGR2RGB::Run(mat);
if (!disable_normalize_and_permute_) {
if (!disable_normalize_) {
// Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
// std::vector<float>(mat->Channels(), 1.0));
// Compute `result = mat * alpha + beta` directly by channel
Expand All @@ -150,6 +150,9 @@ bool SCRFD::Preprocess(Mat* mat, FDTensor* output,
std::vector<float> alpha = {1.f / 128.f, 1.f / 128.f, 1.f / 128.f};
std::vector<float> beta = {-127.5f / 128.f, -127.5f / 128.f, -127.5f / 128.f};
Convert::Run(mat, alpha, beta);
}

if(!disable_permute_){
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
}
Expand Down Expand Up @@ -347,7 +350,6 @@ bool SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};

if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
Expand All @@ -367,8 +369,13 @@ bool SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
}
return true;
}
void SCRFD::DisableNormalizeAndPermute(){
disable_normalize_and_permute_ = true;

void SCRFD::DisableNormalize() {
disable_normalize_=true;
}

void SCRFD::DisablePermute() {
disable_permute_=true;
}
} // namespace facedet
} // namespace vision
Expand Down
10 changes: 7 additions & 3 deletions fastdeploy/vision/facedet/contrib/scrfd.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ class FASTDEPLOY_DECL SCRFD : public FastDeployModel {
unsigned int num_anchors;

/// This function will disable normalize and hwc2chw in preprocessing step.
void DisableNormalizeAndPermute();
void DisableNormalize();

/// This function will disable hwc2chw in preprocessing step.
void DisablePermute();
private:
bool Initialize();

Expand Down Expand Up @@ -121,8 +123,10 @@ class FASTDEPLOY_DECL SCRFD : public FastDeployModel {

std::unordered_map<int, std::vector<SCRFDPoint>> center_points_;

// for recording the switch of normalize and hwc2chw
bool disable_normalize_and_permute_ = false;
// for recording the switch of normalize
bool disable_normalize_ = false;
// for recording the switch of hwc2chw
bool disable_permute_ = false;
};
} // namespace facedet
} // namespace vision
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/vision/facedet/contrib/scrfd_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ void BindSCRFD(pybind11::module& m) {
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def("disable_normalize_and_permute",&vision::facedet::SCRFD::DisableNormalizeAndPermute)
.def("disable_normalize",&vision::facedet::SCRFD::DisableNormalize)
.def("disable_permute",&vision::facedet::SCRFD::DisablePermute)
.def_readwrite("size", &vision::facedet::SCRFD::size)
.def_readwrite("padding_value", &vision::facedet::SCRFD::padding_value)
.def_readwrite("is_mini_pad", &vision::facedet::SCRFD::is_mini_pad)
Expand Down
12 changes: 9 additions & 3 deletions python/fastdeploy/vision/facedet/contrib/scrfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)

def disable_normalize_and_permute(self):
def disable_normalize(self):
"""
This function will disable normalize and hwc2chw in preprocessing step.
This function will disable normalize in preprocessing step.
"""
self._model.disable_normalize_and_permute()
self._model.disable_normalize()

def disable_permute(self):
"""
This function will disable hwc2chw in preprocessing step.
"""
self._model.disable_permute()

# 一些跟SCRFD模型有关的属性封装
# 多数是预处理相关,可通过修改如model.size = [640, 640]改变预处理时resize的大小(前提是模型支持)
Expand Down

This file was deleted.

5 changes: 0 additions & 5 deletions tools/rknpu2/config/RK3568/picodet_s_416_coco_lcnet.yaml

This file was deleted.

7 changes: 0 additions & 7 deletions tools/rknpu2/config/RK3568/scrfd.yaml

This file was deleted.

This file was deleted.

5 changes: 0 additions & 5 deletions tools/rknpu2/config/RK3588/picodet_s_416_coco_lcnet.yaml

This file was deleted.

7 changes: 0 additions & 7 deletions tools/rknpu2/config/RK3588/scrfd.yaml

This file was deleted.

15 changes: 15 additions & 0 deletions tools/rknpu2/config/scrfd.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mean:
-
- 128.5
- 128.5
- 128.5
std:
-
- 128.5
- 128.5
- 128.5
model_path: ./scrfd_500m_bnkps_shape640x640.onnx
outputs_nodes:
do_quantization: True
dataset: "./datasets.txt"
output_folder: "./"
Loading

0 comments on commit dc13eb7

Please sign in to comment.