diff --git a/examples/vision/facedet/blazeface/README.md b/examples/vision/facedet/blazeface/README.md
new file mode 100644
index 0000000000..98c4304121
--- /dev/null
+++ b/examples/vision/facedet/blazeface/README.md
@@ -0,0 +1,34 @@
+English | [简体中文](README_CN.md)
+# BlazeFace Ready-to-deploy Model
+
+- BlazeFace deployment model implementation comes from [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection),and [Pre-training model based on WiderFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
+ - (1)Provided in [Official library
+](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools) *.params, could deploy after operation [export_model.py](#Export PADDLE model);
+ - (2)Developers can train BlazeFace model based on their own data according to [export_model. py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py)After exporting the model, complete the deployment。
+
+## Export PADDLE model
+
+Visit [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) Github library, download and install according to the instructions, download the `. yml` and `. params` model parameters, and use` export_ Model. py `gets the` pad `model file`. yml,. pdiparams,. pdmodel `.
+
+
+* Download BlazeFace model parameter file
+
+|Network structure | input size | number of pictures/GPU | learning rate strategy | Easy/Media/Hard Set | prediction delay (SD855) | model size (MB) | download | configuration file|
+|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
+| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
+| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
+
+* Export paddle-format file
+ ```bash
+ python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
+ ```
+
+## Detailed Deployment Tutorials
+
+- [Python Deployment](python)
+- [C++ Deployment](cpp)
+
+
+## Release Note
+
+- This tutorial and related code are written based on [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
diff --git a/examples/vision/facedet/blazeface/README_CN.md b/examples/vision/facedet/blazeface/README_CN.md
new file mode 100644
index 0000000000..f3957c0ca0
--- /dev/null
+++ b/examples/vision/facedet/blazeface/README_CN.md
@@ -0,0 +1,31 @@
+# BlazeFace准备部署模型
+
+- BlazeFace部署模型实现来自[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection),和[基于WiderFace的预训练模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
+ - (1)[官方库](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools)中提供的*.params,通过[export_model.py](#导出PADDLE模型)操作后,可进行部署;
+ - (2)开发者基于自己数据训练的BlazeFace模型,可按照[export_model.py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py)导出模型后,完成部署。
+
+## 导出PADDLE模型
+
+访问[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)github库,按照指引下载安装,下载`.yml`和`.params` 模型参数,利用 `export_model.py` 得到`paddle`模型文件`.yml, .pdiparams, .pdmodel`。
+
+* 下载BlazeFace模型参数文件
+
+| 网络结构 | 输入尺寸 | 图片个数/GPU | 学习率策略 | Easy/Medium/Hard Set | 预测时延(SD855)| 模型大小(MB) | 下载 | 配置文件 |
+|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
+| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
+| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
+
+* 导出paddle格式文件
+ ```bash
+ python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
+ ```
+
+## 详细部署文档
+
+- [Python部署](python)
+- [C++部署](cpp)
+
+
+## 版本说明
+
+- 本版本文档和代码基于[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) 编写
diff --git a/examples/vision/facedet/blazeface/cpp/CMakeLists.txt b/examples/vision/facedet/blazeface/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..4ec242a44f
--- /dev/null
+++ b/examples/vision/facedet/blazeface/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# Specifies the path to the fastdeploy library after you have downloaded it
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(../../../../../FastDeploy.cmake)
+
+# Add the FastDeploy dependency header
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
+# Add the FastDeploy library dependency
+target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/facedet/blazeface/cpp/README.md b/examples/vision/facedet/blazeface/cpp/README.md
new file mode 100644
index 0000000000..dac9fc4434
--- /dev/null
+++ b/examples/vision/facedet/blazeface/cpp/README.md
@@ -0,0 +1,78 @@
+English | [简体中文](README_CN.md)
+# BlazeFace C++ Deployment Example
+
+This directory provides examples that `infer.cc` fast finishes the deployment of BlazeFace on CPU/GPU。
+
+Before deployment, two steps require confirmation
+
+- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
+- 2. Download the precompiled deployment library and samples code according to your development environment. Refer to [FastDeploy Precompiled Library](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
+
+Taking the CPU inference on Linux as an example, the compilation test can be completed by executing the following command in this directory.
+
+```bash
+mkdir build
+cd build
+# Download the FastDeploy precompiled library. Users can choose your appropriate version in the `FastDeploy Precompiled Library` mentioned above
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
+tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
+make -j
+
+#Download the official converted YOLOv7Face model files and test images
+wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
+
+#Use blazeface-1000e model
+# CPU inference
+./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
+# GPU Inference
+./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
+```
+
+The visualized result after running is as follows
+
+
+
+The above command works for Linux or MacOS. For SDK use-pattern in Windows, refer to:
+- [How to use FastDeploy C++ SDK in Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
+
+## BlazeFace C++ Interface
+
+### BlazeFace Class
+
+```c++
+fastdeploy::vision::facedet::BlazeFace(
+ const string& model_file,
+ const string& params_file = "",
+ const string& config_file = "",
+ const RuntimeOption& runtime_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::PADDLE)
+```
+
+BlazeFace model loading and initialization, among which model_file is the exported PADDLE model format
+
+**Parameter**
+
+> * **model_file**(str): Model file path
+> * **params_file**(str): Parameter file path. Only passing an empty string when the model is in PADDLE format
+> * **config_file**(str): Config file path. Only passing an empty string when the model is in PADDLE format
+> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
+> * **model_format**(ModelFormat): Model format. PADDLE format by default
+
+#### Predict Function
+
+> ```c++
+> BlazeFace::Predict(cv::Mat& im, FaceDetectionResult* result)
+> ```
+>
+> Model prediction interface. Input images and output detection results.
+>
+> **Parameter**
+>
+> > * **im**: Input images in HWC or BGR format
+> > * **result**: Detection results, including detection box and confidence of each box. Refer to [Vision Model Prediction Result](../../../../../docs/api/vision_results/) for FaceDetectionResult
+
+- [Model Description](../../)
+- [Python Deployment](../python)
+- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/facedet/blazeface/cpp/README_CN.md b/examples/vision/facedet/blazeface/cpp/README_CN.md
new file mode 100644
index 0000000000..12b67b6e5b
--- /dev/null
+++ b/examples/vision/facedet/blazeface/cpp/README_CN.md
@@ -0,0 +1,77 @@
+[English](README.md) | 简体中文
+# BlazeFace C++部署示例
+
+本目录下提供`infer.cc`快速完成BlazeFace在CPU/GPU部署的示例。
+
+在部署前,需确认以下两个步骤
+
+- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试
+
+```bash
+mkdir build
+cd build
+# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
+tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
+make -j
+
+#下载官方转换好的BlazeFace模型文件和测试图片
+wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
+
+#使用blazeface-1000e模型
+# CPU推理
+./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
+# GPU推理
+./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
+
+运行完成可视化结果如下图所示
+
+
+
+以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
+- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)
+
+## BlazeFace C++接口
+
+### BlazeFace类
+
+```c++
+fastdeploy::vision::facedet::BlazeFace(
+ const string& model_file,
+ const string& params_file = "",
+ const string& config_file = "",
+ const RuntimeOption& runtime_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::PADDLE)
+```
+
+BlazeFace模型加载和初始化,其中model_file为导出的PADDLE模型格式。
+
+**参数**
+
+> * **model_file**(str): 模型文件路径
+> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
+> * **config_file**(str): 配置文件路径,当模型格式为ONNX时,此参数传入空字符串即可
+> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
+> * **model_format**(ModelFormat): 模型格式,默认为PADDLE格式
+
+#### Predict函数
+
+> ```c++
+> BlazeFace::Predict(cv::Mat& im, FaceDetectionResult* result)
+> ```
+>
+> 模型预测接口,输入图像直接输出检测结果。
+>
+> **参数**
+>
+> > * **im**: 输入图像,注意需为HWC,BGR格式
+> > * **result**: 检测结果,包括检测框,各个框的置信度, FaceDetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
+
+- [模型介绍](../../)
+- [Python部署](../python)
+- [视觉模型预测结果](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/facedet/blazeface/cpp/infer.cc b/examples/vision/facedet/blazeface/cpp/infer.cc
new file mode 100644
index 0000000000..c4304f45ff
--- /dev/null
+++ b/examples/vision/facedet/blazeface/cpp/infer.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+#ifdef WIN32
+const char sep = '\\';
+#else
+const char sep = '/';
+#endif
+
+void CpuInfer(const std::string& model_dir, const std::string& image_file) {
+ auto model_file = model_dir + sep + "model.pdmodel";
+ auto params_file = model_dir + sep + "model.pdiparams";
+ auto config_file = model_dir + sep + "infer_cfg.yml";
+ auto option = fastdeploy::RuntimeOption();
+ option.UseCpu();
+ auto model = fastdeploy::vision::facedet::BlazeFace(
+ model_file, params_file, config_file, option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+
+ fastdeploy::vision::FaceDetectionResult res;
+ if (!model.Predict(im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ std::cout << res.Str() << std::endl;
+
+ auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
+ cv::imwrite("vis_result.jpg", vis_im);
+ std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+void GpuInfer(const std::string& model_dir, const std::string& image_file) {
+ auto model_file = model_dir + sep + "model.pdmodel";
+ auto params_file = model_dir + sep + "model.pdiparams";
+ auto config_file = model_dir + sep + "infer_cfg.yml";
+ auto option = fastdeploy::RuntimeOption();
+ option.UseGpu();
+ auto model = fastdeploy::vision::facedet::BlazeFace(
+ model_file, params_file, config_file, option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+
+ fastdeploy::vision::FaceDetectionResult res;
+ if (!model.Predict(im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ std::cout << res.Str() << std::endl;
+
+ auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
+ cv::imwrite("vis_result.jpg", vis_im);
+ std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+ if (argc < 4) {
+ std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
+ "e.g ./infer_model yolov5s-face.onnx ./test.jpeg 0"
+ << std::endl;
+ std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+ "with gpu; 2: run with gpu and use tensorrt backend."
+ << std::endl;
+ return -1;
+ }
+
+ if (std::atoi(argv[3]) == 0) {
+ CpuInfer(argv[1], argv[2]);
+ } else if (std::atoi(argv[3]) == 1) {
+ GpuInfer(argv[1], argv[2]);
+ }
+ return 0;
+}
diff --git a/examples/vision/facedet/blazeface/python/README.md b/examples/vision/facedet/blazeface/python/README.md
new file mode 100644
index 0000000000..b645317cd5
--- /dev/null
+++ b/examples/vision/facedet/blazeface/python/README.md
@@ -0,0 +1,68 @@
+English | [简体中文](README_CN.md)
+# BlazeFace Python Deployment Example
+
+Before deployment, two steps require confirmation
+
+- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
+- 2. Install FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
+
+This directory provides examples that `infer.py` fast finishes the deployment of BlazeFace on CPU/GPU.
+
+```bash
+# Download the example code for deployment
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd examples/vision/facedet/blazeface/python/
+
+# Download BlazeFace model files and test images
+wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
+
+# Use blazeface-1000e model
+# CPU Inference
+python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
+# GPU Inference
+python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
+```
+
+The visualized result after running is as follows
+
+
+
+## BlazeFace Python Interface
+
+```python
+fastdeploy.vision.facedet.BlzaeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
+```
+
+BlazeFace model loading and initialization, among which model_file is the exported PADDLE model format
+
+**Parameter**
+
+> * **model_file**(str): Model file path
+> * **params_file**(str): Parameter file path. No need to set when the model is in PADDLE format
+> * **config_file**(str): config file path. No need to set when the model is in PADDLE format
+> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
+> * **model_format**(ModelFormat): Model format. PADDLE format by default
+
+### predict function
+
+> ```python
+> BlazeFace.predict(input_image)
+> ```
+> Through let BlazeFace.postprocessor.conf_threshold = 0.2,to modify conf_threshold
+>
+> Model prediction interface. Input images and output detection results.
+>
+> **Parameter**
+>
+> > * **input_image**(np.ndarray): Input image in HWC or BGR format
+
+> **Return**
+>
+> > Return`fastdeploy.vision.FaceDetectionResult` structure. Refer to [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description.
+
+## Other Documents
+
+- [BlazeFace Model Description](..)
+- [BlazeFace C++ Deployment](../cpp)
+- [Model Prediction Results](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/facedet/blazeface/python/README_CN.md b/examples/vision/facedet/blazeface/python/README_CN.md
new file mode 100644
index 0000000000..3bbc620e20
--- /dev/null
+++ b/examples/vision/facedet/blazeface/python/README_CN.md
@@ -0,0 +1,68 @@
+[English](README.md) | 简体中文
+# BlazeFace Python部署示例
+
+在部署前,需确认以下两个步骤
+
+- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+本目录下提供`infer.py`快速完成BlazeFace在CPU/GPU部署的示例。执行如下脚本即可完成
+
+```bash
+#下载部署示例代码
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd examples/vision/facedet/blazeface/python/
+
+#下载BlazeFace模型文件和测试图片
+wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
+
+#使用blazeface-1000e模型
+# CPU推理
+python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
+# GPU推理
+python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
+```
+
+运行完成可视化结果如下图所示
+
+
+
+## BlazeFace Python接口
+
+```python
+fastdeploy.vision.facedet.BlzaeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
+```
+
+BlazeFace模型加载和初始化
+
+**参数**
+
+> * **model_file**(str): 模型文件路径
+> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
+> * **config_file**(str): config文件路径,当模型格式为ONNX格式时,此参数无需设定
+> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
+> * **model_format**(ModelFormat): 模型格式,默认为PADDLE
+
+### predict函数
+
+> ```python
+> BlazeFace.predict(input_image)
+> ```
+> 通过BlazeFace.postprocessor.conf_threshold = 0.2,来修改conf_threshold
+>
+> 模型预测结口,输入图像直接输出检测结果。
+>
+> **参数**
+>
+> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
+
+> **返回**
+>
+> > 返回`fastdeploy.vision.FaceDetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
+
+## 其它文档
+
+- [BlazeFace 模型介绍](..)
+- [BlazeFace C++部署](../cpp)
+- [模型预测结果说明](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/facedet/blazeface/python/infer.py b/examples/vision/facedet/blazeface/python/infer.py
new file mode 100644
index 0000000000..b9904f9c0c
--- /dev/null
+++ b/examples/vision/facedet/blazeface/python/infer.py
@@ -0,0 +1,58 @@
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+ import argparse
+ import ast
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model", required=True, help="Path of blazeface model dir.")
+ parser.add_argument(
+ "--image", required=True, help="Path of test image file.")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default='cpu',
+ help="Type of inference device, support 'cpu' or 'gpu'.")
+ parser.add_argument(
+ "--use_trt",
+ type=ast.literal_eval,
+ default=False,
+ help="Wether to use tensorrt.")
+ return parser.parse_args()
+
+
+def build_option(args):
+ option = fd.RuntimeOption()
+
+ if args.device.lower() == "gpu":
+ option.use_gpu()
+
+ if args.use_trt:
+ option.use_trt_backend()
+ option.set_trt_input_shape("images", [1, 3, 640, 640])
+ return option
+
+
+args = parse_arguments()
+
+model_dir = args.model
+
+model_file = os.path.join(model_dir, "model.pdmodel")
+params_file = os.path.join(model_dir, "model.pdiparams")
+config_file = os.path.join(model_dir, "infer_cfg.yml")
+
+# Configure runtime and load the model
+runtime_option = build_option(args)
+model = fd.vision.facedet.BlazeFace(model_file, params_file, config_file, runtime_option=runtime_option)
+
+# Predict image detection results
+im = cv2.imread(args.image)
+result = model.predict(im)
+print(result)
+# Visualization of prediction Results
+vis_im = fd.vision.vis_face_detection(im, result)
+cv2.imwrite("visualized_result.jpg", vis_im)
+print("Visualized result save in ./visualized_result.jpg")
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 8788e889dc..28721160c3 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -41,6 +41,7 @@
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
#include "fastdeploy/vision/facedet/contrib/yolov5face.h"
#include "fastdeploy/vision/facedet/contrib/yolov7face/yolov7face.h"
+#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
#include "fastdeploy/vision/faceid/contrib/insightface/model.h"
#include "fastdeploy/vision/faceid/contrib/adaface/adaface.h"
#include "fastdeploy/vision/headpose/contrib/fsanet.h"
diff --git a/fastdeploy/vision/facedet/facedet_pybind.cc b/fastdeploy/vision/facedet/facedet_pybind.cc
index e5a62542db..cf12399e2c 100644
--- a/fastdeploy/vision/facedet/facedet_pybind.cc
+++ b/fastdeploy/vision/facedet/facedet_pybind.cc
@@ -20,6 +20,7 @@ void BindRetinaFace(pybind11::module& m);
void BindUltraFace(pybind11::module& m);
void BindYOLOv5Face(pybind11::module& m);
void BindYOLOv7Face(pybind11::module& m);
+void BindBlazeFace(pybind11::module& m);
void BindSCRFD(pybind11::module& m);
void BindFaceDet(pybind11::module& m) {
@@ -28,6 +29,7 @@ void BindFaceDet(pybind11::module& m) {
BindUltraFace(facedet_module);
BindYOLOv5Face(facedet_module);
BindYOLOv7Face(facedet_module);
+ BindBlazeFace(facedet_module);
BindSCRFD(facedet_module);
}
} // namespace fastdeploy
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.cc b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.cc
new file mode 100644
index 0000000000..5541f5d676
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.cc
@@ -0,0 +1,93 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy{
+
+namespace vision{
+
+namespace facedet{
+
+BlazeFace::BlazeFace(const std::string& model_file,
+ const std::string& params_file,
+ const std::string& config_file,
+ const RuntimeOption& custom_option,
+ const ModelFormat& model_format)
+ : preprocessor_(config_file){
+ valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
+ valid_gpu_backends = {Backend::OPENVINO, Backend::LITE, Backend::PDINFER};
+ runtime_option = custom_option;
+ runtime_option.model_format = model_format;
+ runtime_option.model_file = model_file;
+ runtime_option.params_file = params_file;
+ initialized = Initialize();
+}
+
+bool BlazeFace::Initialize(){
+ if (!InitRuntime()){
+ FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+ return false;
+ }
+ return true;
+}
+
+bool BlazeFace::Predict(const cv::Mat& im, FaceDetectionResult* result){
+ std::vector results;
+ if (!this->BatchPredict({im}, &results)) {
+ return false;
+ }
+ *result = std::move(results[0]);
+ return true;
+}
+
+bool BlazeFace::BatchPredict(const std::vector& images,
+ std::vector* results){
+ std::vector fd_images = WrapMat(images);
+ FDASSERT(images.size() == 1, "Only support batch = 1 now.");
+ std::vector>> ims_info;
+ if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) {
+ FDERROR << "Failed to preprocess the input image." << std::endl;
+ return false;
+ }
+
+ reused_input_tensors_[0].name = "image";
+ reused_input_tensors_[1].name = "scale_factor";
+ reused_input_tensors_[2].name = "im_shape";
+
+ // Some models don't need scale_factor and im_shape as input
+ while (reused_input_tensors_.size() != NumInputsOfRuntime()) {
+ reused_input_tensors_.pop_back();
+ }
+
+ if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
+ FDERROR << "Failed to inference by runtime." << std::endl;
+ return false;
+ }
+
+ if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)){
+ FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace facedet
+
+} // namespace vision
+
+} // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h
new file mode 100644
index 0000000000..b740240a82
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h
@@ -0,0 +1,83 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
+#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace facedet {
+/*! @brief BlazeFace model object used when to load a BlazeFace model exported by BlazeFace.
+ */
+class FASTDEPLOY_DECL BlazeFace: public FastDeployModel{
+ public:
+ /** \brief Set path of model file and the configuration of runtime.
+ *
+ * \param[in] model_file Path of model file, e.g ./blazeface.onnx
+ * \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+ * \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml
+ * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
+ * \param[in] model_format Model format of the loaded model, default is ONNX format
+ */
+ BlazeFace(const std::string& model_file, const std::string& params_file = "",
+ const std::string& config_file = "",
+ const RuntimeOption& custom_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::PADDLE);
+
+ std::string ModelName() {return "blaze-face";}
+
+ /** \brief Predict the detection result for an input image
+ *
+ * \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
+ * \param[in] result The output detection result will be writen to this structure
+ * \return true if the prediction successed, otherwise false
+ */
+ bool Predict(const cv::Mat& im, FaceDetectionResult* result);
+
+ /** \brief Predict the detection results for a batch of input images
+ *
+ * \param[in] imgs, The input image list, each element comes from cv::imread()
+ * \param[in] results The output detection result list
+ * \return true if the prediction successed, otherwise false
+ */
+ virtual bool BatchPredict(const std::vector& images,
+ std::vector* results);
+
+ /// Get preprocessor reference of BlazeFace
+ virtual BlazeFacePreprocessor& GetPreprocessor() {
+ return preprocessor_;
+ }
+
+ /// Get postprocessor reference of BlazeFace
+ virtual BlazeFacePostprocessor& GetPostprocessor() {
+ return postprocessor_;
+ }
+
+ protected:
+ bool Initialize();
+ BlazeFacePreprocessor preprocessor_;
+ BlazeFacePostprocessor postprocessor_;
+};
+
+} // namespace facedet
+
+} // namespace vision
+
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/blazeface_pybind.cc b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface_pybind.cc
new file mode 100644
index 0000000000..cc0066d402
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/blazeface_pybind.cc
@@ -0,0 +1,84 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindBlazeFace(pybind11::module& m) {
+ pybind11::class_(
+ m, "BlazeFacePreprocessor")
+ .def(pybind11::init<>())
+ .def("run", [](vision::facedet::BlazeFacePreprocessor& self, std::vector& im_list) {
+ std::vector images;
+ for (size_t i = 0; i < im_list.size(); ++i) {
+ images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
+ }
+ std::vector outputs;
+ std::vector>> ims_info;
+ if (!self.Run(&images, &outputs, &ims_info)) {
+ throw std::runtime_error("Failed to preprocess the input data in BlazeFacePreprocessor.");
+ }
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ outputs[i].StopSharing();
+ }
+ return make_pair(outputs, ims_info);
+ });
+
+ pybind11::class_(
+ m, "BlazeFacePostprocessor")
+ .def(pybind11::init<>())
+ .def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector& inputs,
+ const std::vector>>& ims_info) {
+ std::vector results;
+ if (!self.Run(inputs, &results, ims_info)) {
+ throw std::runtime_error("Failed to postprocess the runtime result in BlazeFacePostprocessor.");
+ }
+ return results;
+ })
+ .def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector& input_array,
+ const std::vector>>& ims_info) {
+ std::vector results;
+ std::vector inputs;
+ PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
+ if (!self.Run(inputs, &results, ims_info)) {
+ throw std::runtime_error("Failed to postprocess the runtime result in BlazePostprocessor.");
+ }
+ return results;
+ })
+ .def_property("conf_threshold", &vision::facedet::BlazeFacePostprocessor::GetConfThreshold, &vision::facedet::BlazeFacePostprocessor::SetConfThreshold)
+ .def_property("nms_threshold", &vision::facedet::BlazeFacePostprocessor::GetNMSThreshold, &vision::facedet::BlazeFacePostprocessor::SetNMSThreshold);
+
+ pybind11::class_(m, "BlazeFace")
+ .def(pybind11::init())
+ .def("predict",
+ [](vision::facedet::BlazeFace& self, pybind11::array& data) {
+ auto mat = PyArrayToCvMat(data);
+ vision::FaceDetectionResult res;
+ self.Predict(mat, &res);
+ return res;
+ })
+ .def("batch_predict", [](vision::facedet::BlazeFace& self, std::vector& data) {
+ std::vector images;
+ for (size_t i = 0; i < data.size(); ++i) {
+ images.push_back(PyArrayToCvMat(data[i]));
+ }
+ std::vector results;
+ self.BatchPredict(images, &results);
+ return results;
+ })
+ .def_property_readonly("preprocessor", &vision::facedet::BlazeFace::GetPreprocessor)
+ .def_property_readonly("postprocessor", &vision::facedet::BlazeFace::GetPostprocessor);
+}
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.cc b/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.cc
new file mode 100644
index 0000000000..8624a5c8ca
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
+#include "fastdeploy/vision/utils/utils.h"
+#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace facedet {
+
+BlazeFacePostprocessor::BlazeFacePostprocessor() {
+ conf_threshold_ = 0.5;
+ nms_threshold_ = 0.3;
+}
+
+bool BlazeFacePostprocessor::Run(const std::vector& tensors,
+ std::vector* results,
+ const std::vector>>& ims_info) {
+ // Get number of boxes for each input image
+ std::vector num_boxes(tensors[1].shape[0]);
+ int total_num_boxes = 0;
+ if (tensors[1].dtype == FDDataType::INT32) {
+ const auto* data = static_cast(tensors[1].CpuData());
+ for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
+ num_boxes[i] = static_cast(data[i]);
+ total_num_boxes += num_boxes[i];
+ }
+ } else if (tensors[1].dtype == FDDataType::INT64) {
+ const auto* data = static_cast(tensors[1].CpuData());
+ for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
+ num_boxes[i] = static_cast(data[i]);
+ }
+ }
+
+ // Special case for TensorRT, it has fixed output shape of NMS
+ // So there's invalid boxes in its' output boxes
+ int num_output_boxes = static_cast(tensors[0].Shape()[0]);
+ bool contain_invalid_boxes = false;
+ if (total_num_boxes != num_output_boxes) {
+ if (num_output_boxes % num_boxes.size() == 0) {
+ contain_invalid_boxes = true;
+ } else {
+ FDERROR << "Cannot handle the output data for this model, unexpected "
+ "situation."
+ << std::endl;
+ return false;
+ }
+ }
+
+ // Get boxes for each input image
+ results->resize(num_boxes.size());
+
+ if (tensors[0].shape[0] == 0) {
+ // No detected boxes
+ return true;
+ }
+
+ const auto* box_data = static_cast(tensors[0].CpuData());
+ int offset = 0;
+ for (size_t i = 0; i < num_boxes.size(); ++i) {
+ const float* ptr = box_data + offset;
+ (*results)[i].Reserve(num_boxes[i]);
+ for (size_t j = 0; j < num_boxes[i]; ++j) {
+ if (ptr[j * 6 + 1] > conf_threshold_) {
+ (*results)[i].scores.push_back(ptr[j * 6 + 1]);
+ (*results)[i].boxes.emplace_back(std::array(
+ {ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
+ }
+ }
+ if (contain_invalid_boxes) {
+ offset += static_cast(num_output_boxes * 6 / num_boxes.size());
+ } else {
+ offset += static_cast(num_boxes[i] * 6);
+ }
+ }
+return true;
+}
+
+} // namespace detection
+} // namespace vision
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h b/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h
new file mode 100644
index 0000000000..b7443a1409
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace facedet {
+
+class FASTDEPLOY_DECL BlazeFacePostprocessor{
+ public:
+ /*! @brief Postprocessor object for BlazeFace serials model.
+ */
+ BlazeFacePostprocessor();
+
+ /** \brief Process the result of runtime and fill to FaceDetectionResult structure
+ *
+ * \param[in] infer_result The inference result from runtime
+ * \param[in] results The output result of detection
+ * \param[in] ims_info The shape info list, record input_shape and output_shape
+ * \return true if the postprocess successed, otherwise false
+ */
+ bool Run(const std::vector& infer_result,
+ std::vector* results,
+ const std::vector>>& ims_info);
+
+ /// Set conf_threshold, default 0.5
+ void SetConfThreshold(const float& conf_threshold) {
+ conf_threshold_ = conf_threshold;
+ }
+
+ /// Get conf_threshold, default 0.5
+ float GetConfThreshold() const { return conf_threshold_; }
+
+ /// Set nms_threshold, default 0.3
+ void SetNMSThreshold(const float& nms_threshold) {
+ nms_threshold_ = nms_threshold;
+ }
+
+ /// Get nms_threshold, default 0.3
+ float GetNMSThreshold() const { return nms_threshold_; }
+
+ protected:
+ float conf_threshold_;
+ float nms_threshold_;
+};
+
+} // namespace facedet
+} // namespace vision
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.cc b/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.cc
new file mode 100644
index 0000000000..a259f4a505
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.cc
@@ -0,0 +1,207 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
+#include "fastdeploy/function/concat.h"
+#include "fastdeploy/function/pad.h"
+#include "fastdeploy/vision/common/processors/mat.h"
+#include "yaml-cpp/yaml.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace facedet {
+
+BlazeFacePreprocessor::BlazeFacePreprocessor(const std::string& config_file) {
+ is_scale_ = false;
+ normalize_mean_ = {123, 117, 104};
+ normalize_std_ = {127.502231, 127.502231, 127.502231};
+ this->config_file_ = config_file;
+ FDASSERT(BuildPreprocessPipelineFromConfig(),
+ "Failed to create PaddleDetPreprocessor.");
+}
+
+bool BlazeFacePreprocessor::Run(std::vector* images, std::vector* outputs,
+ std::vector>>* ims_info) {
+ if (images->size() == 0) {
+ FDERROR << "The size of input images should be greater than 0." << std::endl;
+ return false;
+ }
+ ims_info->resize(images->size());
+ outputs->resize(3);
+ int batch = static_cast(images->size());
+ // Allocate memory for scale_factor
+ (*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
+ // Allocate memory for im_shape
+ (*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
+
+ std::vector max_hw({-1, -1});
+
+ auto* scale_factor_ptr =
+ reinterpret_cast((*outputs)[1].MutableData());
+ auto* im_shape_ptr = reinterpret_cast((*outputs)[2].MutableData());
+
+ // Concat all the preprocessed data to a batch tensor
+ std::vector im_tensors(images->size());
+
+ for (size_t i = 0; i < images->size(); ++i) {
+ int origin_w = (*images)[i].Width();
+ int origin_h = (*images)[i].Height();
+ scale_factor_ptr[2 * i] = 1.0;
+ scale_factor_ptr[2 * i + 1] = 1.0;
+
+ for (size_t j = 0; j < processors_.size(); ++j) {
+ if (!(*(processors_[j].get()))(&((*images)[i]))) {
+ FDERROR << "Failed to processs image:" << i << " in "
+ << processors_[i]->Name() << "." << std::endl;
+ return false;
+ }
+ if (processors_[j]->Name().find("Resize") != std::string::npos) {
+ scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
+ scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
+ }
+ }
+
+ if ((*images)[i].Height() > max_hw[0]) {
+ max_hw[0] = (*images)[i].Height();
+ }
+ if ((*images)[i].Width() > max_hw[1]) {
+ max_hw[1] = (*images)[i].Width();
+ }
+ im_shape_ptr[2 * i] = max_hw[0];
+ im_shape_ptr[2 * i + 1] = max_hw[1];
+
+ if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
+ // if the size of image less than max_hw, pad to max_hw
+ FDTensor tensor;
+ (*images)[i].ShareWithTensor(&tensor);
+ function::Pad(tensor, &(im_tensors[i]),
+ {0, 0, max_hw[0] - (*images)[i].Height(),
+ max_hw[1] - (*images)[i].Width()},
+ 0);
+ } else {
+ // No need pad
+ (*images)[i].ShareWithTensor(&(im_tensors[i]));
+ }
+ // Reshape to 1xCxHxW
+ im_tensors[i].ExpandDim(0);
+ }
+
+ if (im_tensors.size() == 1) {
+ // If there's only 1 input, no need to concat
+ // skip memory copy
+ (*outputs)[0] = std::move(im_tensors[0]);
+ } else {
+ // Else concat the im tensor for each input image
+ // compose a batched input tensor
+ function::Concat(im_tensors, &((*outputs)[0]), 0);
+ }
+
+ return true;
+}
+
+bool BlazeFacePreprocessor::BuildPreprocessPipelineFromConfig() {
+ processors_.clear();
+ YAML::Node cfg;
+ try {
+ cfg = YAML::LoadFile(config_file_);
+ } catch (YAML::BadFile& e) {
+ FDERROR << "Failed to load yaml file " << config_file_
+ << ", maybe you should check this file." << std::endl;
+ return false;
+ }
+
+ processors_.push_back(std::make_shared());
+
+ bool has_permute = false;
+ for (const auto& op : cfg["Preprocess"]) {
+ std::string op_name = op["type"].as();
+ if (op_name == "NormalizeImage") {
+ auto mean = op["mean"].as>();
+ auto std = op["std"].as>();
+ bool is_scale = true;
+ if (op["is_scale"]) {
+ is_scale = op["is_scale"].as();
+ }
+ std::string norm_type = "mean_std";
+ if (op["norm_type"]) {
+ norm_type = op["norm_type"].as();
+ }
+ if (norm_type != "mean_std") {
+ std::fill(mean.begin(), mean.end(), 0.0);
+ std::fill(std.begin(), std.end(), 1.0);
+ }
+ processors_.push_back(std::make_shared(mean, std, is_scale));
+ } else if (op_name == "Resize") {
+ bool keep_ratio = op["keep_ratio"].as();
+ auto target_size = op["target_size"].as>();
+ int interp = op["interp"].as();
+ FDASSERT(target_size.size() == 2,
+ "Require size of target_size be 2, but now it's %lu.",
+ target_size.size());
+ if (!keep_ratio) {
+ int width = target_size[1];
+ int height = target_size[0];
+ processors_.push_back(
+ std::make_shared(width, height, -1.0, -1.0, interp, false));
+ } else {
+ int min_target_size = std::min(target_size[0], target_size[1]);
+ int max_target_size = std::max(target_size[0], target_size[1]);
+ std::vector max_size;
+ if (max_target_size > 0) {
+ max_size.push_back(max_target_size);
+ max_size.push_back(max_target_size);
+ }
+ processors_.push_back(std::make_shared(
+ min_target_size, interp, true, max_size));
+ }
+ } else if (op_name == "Permute") {
+ // Do nothing, do permute as the last operation
+ has_permute = true;
+ continue;
+ } else if (op_name == "Pad") {
+ auto size = op["size"].as>();
+ auto value = op["fill_value"].as>();
+ processors_.push_back(std::make_shared("float"));
+ processors_.push_back(
+ std::make_shared(size[1], size[0], value));
+ } else if (op_name == "PadStride") {
+ auto stride = op["stride"].as();
+ processors_.push_back(
+ std::make_shared(stride, std::vector(3, 0)));
+ } else {
+ FDERROR << "Unexcepted preprocess operator: " << op_name << "."
+ << std::endl;
+ return false;
+ }
+ }
+
+ if (has_permute) {
+ // permute = cast + HWC2CHW
+ processors_.push_back(std::make_shared("float"));
+ processors_.push_back(std::make_shared());
+ }
+
+ // Fusion will improve performance
+ FuseTransforms(&processors_);
+
+ return true;
+}
+
+} // namespace facedet
+
+} // namespace vision
+
+} // namespacefastdeploy
\ No newline at end of file
diff --git a/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h b/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h
new file mode 100644
index 0000000000..836fd6bfb4
--- /dev/null
+++ b/fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace facedet {
+
+class FASTDEPLOY_DECL BlazeFacePreprocessor:
+ public fastdeploy::vision::detection::PaddleDetPreprocessor {
+ public:
+ /** \brief Create a preprocessor instance for BlazeFace serials model
+ */
+ BlazeFacePreprocessor() = default;
+
+ /** \brief Create a preprocessor instance for Blazeface serials model
+ *
+ * \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
+ */
+ explicit BlazeFacePreprocessor(const std::string& config_file);
+
+ /** \brief Process the input image and prepare input tensors for runtime
+ *
+ * \param[in] images The input image data list, all the elements are returned by cv::imread()
+ * \param[in] outputs The output tensors which will feed in runtime
+ * \param[in] ims_info The shape info list, record input_shape and output_shape
+ * \ret
+ */
+ bool Run(std::vector* images, std::vector* outputs,
+ std::vector>>* ims_info);
+
+ private:
+ bool BuildPreprocessPipelineFromConfig();
+
+ // if is_scale_up is false, the input image only can be zoom out,
+ // the maximum resize scale cannot exceed 1.0
+ bool is_scale_;
+
+ std::vector normalize_mean_;
+
+ std::vector normalize_std_;
+
+ std::vector> processors_;
+ // read config file
+ std::string config_file_;
+};
+
+} // namespace facedet
+
+} // namespace vision
+
+} // namespace fastdeploy
diff --git a/python/fastdeploy/vision/facedet/__init__.py b/python/fastdeploy/vision/facedet/__init__.py
index 869657a3c7..a96cb791c8 100644
--- a/python/fastdeploy/vision/facedet/__init__.py
+++ b/python/fastdeploy/vision/facedet/__init__.py
@@ -15,6 +15,7 @@
from __future__ import absolute_import
from .contrib.yolov5face import YOLOv5Face
from .contrib.yolov7face import *
+from .contrib.blazeface import *
from .contrib.retinaface import RetinaFace
from .contrib.scrfd import SCRFD
from .contrib.ultraface import UltraFace
diff --git a/python/fastdeploy/vision/facedet/contrib/blazeface.py b/python/fastdeploy/vision/facedet/contrib/blazeface.py
new file mode 100644
index 0000000000..f67b6ee3d2
--- /dev/null
+++ b/python/fastdeploy/vision/facedet/contrib/blazeface.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class BlazeFacePreprocessor:
+ def __init__(self):
+ """Create a preprocessor for BlazeFace
+ """
+ self._preprocessor = C.vision.facedet.BlazeFacePreprocessor()
+
+ def run(self, input_ims):
+ """Preprocess input images for BlazeFace
+
+ :param: input_ims: (list of numpy.ndarray)The input image
+ :return: list of FDTensor
+ """
+ return self._preprocessor.run(input_ims)
+
+ @property
+ def is_scale_(self):
+ """
+ is_scale_ for preprocessing, the input image only can be zoom out, the maximum resize scale cannot exceed 1.0, default true
+ """
+ return self._preprocessor.is_scale_
+
+ @is_scale_.setter
+ def is_scale_(self, value):
+ assert isinstance(
+ value,
+ bool), "The value to set `is_scale_` must be type of bool."
+ self._preprocessor.is_scale_ = value
+
+
+class BlazeFacePostprocessor:
+ def __init__(self):
+ """Create a postprocessor for BlazeFace
+ """
+ self._postprocessor = C.vision.facedet.BlazeFacePostprocessor()
+
+ def run(self, runtime_results, ims_info):
+ """Postprocess the runtime results for BlazeFace
+
+ :param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
+ :param: ims_info: (list of dict)Record input_shape and output_shape
+ :return: list of DetectionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
+ """
+ return self._postprocessor.run(runtime_results, ims_info)
+
+ @property
+ def conf_threshold(self):
+ """
+ confidence threshold for postprocessing, default is 0.5
+ """
+ return self._postprocessor.conf_threshold
+
+ @property
+ def nms_threshold(self):
+ """
+ nms threshold for postprocessing, default is 0.3
+ """
+ return self._postprocessor.nms_threshold
+
+ @conf_threshold.setter
+ def conf_threshold(self, conf_threshold):
+ assert isinstance(conf_threshold, float),\
+ "The value to set `conf_threshold` must be type of float."
+ self._postprocessor.conf_threshold = conf_threshold
+
+ @nms_threshold.setter
+ def nms_threshold(self, nms_threshold):
+ assert isinstance(nms_threshold, float),\
+ "The value to set `nms_threshold` must be type of float."
+ self._postprocessor.nms_threshold = nms_threshold
+
+
+class BlazeFace(FastDeployModel):
+ def __init__(self,
+ model_file,
+ params_file="",
+ config_file="",
+ runtime_option=None,
+ model_format=ModelFormat.PADDLE):
+ """Load a BlazeFace model exported by BlazeFace.
+
+ :param model_file: (str)Path of model file, e.g ./Blazeface.onnx
+ :param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
+ :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
+ :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
+ """
+ super(BlazeFace, self).__init__(runtime_option)
+
+ self._model = C.vision.facedet.BlazeFace(
+ model_file, params_file, config_file, self._runtime_option, model_format)
+
+ assert self.initialized, "BlazeFace initialize failed."
+
+ def predict(self, input_image):
+ """Detect the location and key points of human faces from an input image
+ :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+ :return: FaceDetectionResult
+ """
+ return self._model.predict(input_image)
+
+ def batch_predict(self, images):
+ """Classify a batch of input image
+
+ :param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
+ :return list of FaceDetectionResult
+ """
+
+ return self._model.batch_predict(images)
+
+ @property
+ def preprocessor(self):
+ """Get BlazefacePreprocessor object of the loaded model
+
+ :return BlazefacePreprocessor
+ """
+ return self._model.preprocessor
+
+ @property
+ def postprocessor(self):
+ """Get BlazefacePostprocessor object of the loaded model
+
+ :return BlazefacePostprocessor
+ """
+ return self._model.postprocessor
diff --git a/tests/models/test_blazeface.py b/tests/models/test_blazeface.py
new file mode 100644
index 0000000000..70bafd6938
--- /dev/null
+++ b/tests/models/test_blazeface.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from fastdeploy import ModelFormat
+import fastdeploy as fd
+import cv2
+import os
+import pickle
+import numpy as np
+import runtime_config as rc
+
+
+def test_detection_blazeface():
+ model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
+ input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+ input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg"
+ result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
+ result_url2 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result2.pkl"
+ fd.download_and_decompress(model_url, "resources")
+ fd.download(input_url1, "resources")
+ fd.download(input_url2, "resources")
+
+
+ model_dir = "resources/blazeface_1000e"
+ model_file = os.path.join(model_dir, "model.pdmodel")
+ params_file = os.path.join(model_dir, "model.pdiparams")
+ config_file = os.path.join(model_dir, "infer_cfg.yml")
+ model = fd.vision.facedet.BlazeFace(
+ model_file, params_file, config_file, runtime_option=rc.test_option)
+ model.postprocessor.conf_threshold = 0.5
+
+ with open("resources/blazeface_result1.pkl", "rb") as f:
+ expect1 = pickle.load(f)
+
+ with open("resources/blazeface_result2.pkl", "rb") as f:
+ expect2 = pickle.load(f)
+
+ im1 = cv2.imread("./resources/000000014439.jpg")
+ im2 = cv2.imread("./resources/000000570688.jpg")
+
+ for i in range(3):
+ # test single predict
+ result1 = model.predict(im1)
+ result2 = model.predict(im2)
+
+ diff_boxes_1 = np.fabs(
+ np.array(result1.boxes) - np.array(expect1["boxes"]))
+ diff_boxes_2 = np.fabs(
+ np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+ diff_scores_1 = np.fabs(
+ np.array(result1.scores) - np.array(expect1["scores"]))
+ diff_scores_2 = np.fabs(
+ np.array(result2.scores) - np.array(expect2["scores"]))
+
+ assert diff_boxes_1.max(
+ ) < 1e-04, "There's difference in detection boxes 1."
+ assert diff_scores_1.max(
+ ) < 1e-04, "There's difference in detection score 1."
+
+ assert diff_boxes_2.max(
+ ) < 1e-03, "There's difference in detection boxes 2."
+ assert diff_scores_2.max(
+ ) < 1e-04, "There's difference in detection score 2."
+
+ print("one image test success!")
+
+ # test batch predict
+ results = model.batch_predict([im1, im2])
+ result1 = results[0]
+ result2 = results[1]
+
+ diff_boxes_1 = np.fabs(
+ np.array(result1.boxes) - np.array(expect1["boxes"]))
+ diff_boxes_2 = np.fabs(
+ np.array(result2.boxes) - np.array(expect2["boxes"]))
+
+ diff_scores_1 = np.fabs(
+ np.array(result1.scores) - np.array(expect1["scores"]))
+ diff_scores_2 = np.fabs(
+ np.array(result2.scores) - np.array(expect2["scores"]))
+ assert diff_boxes_1.max(
+ ) < 1e-04, "There's difference in detection boxes 1."
+ assert diff_scores_1.max(
+ ) < 1e-03, "There's difference in detection score 1."
+
+ assert diff_boxes_2.max(
+ ) < 1e-04, "There's difference in detection boxes 2."
+ assert diff_scores_2.max(
+ ) < 1e-04, "There's difference in detection score 2."
+
+ print("batch predict success!")
+
+
+def test_detection_blazeface_runtime():
+ model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
+ input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
+ result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
+ fd.download_and_decompress(model_url, "resources")
+ fd.download(input_url1, "resources")
+ fd.download(result_url1, "resources")
+
+ model_dir = "resources/blazeface_1000e"
+ model_file = os.path.join(model_dir, "model.pdmodel")
+ params_file = os.path.join(model_dir, "model.pdiparams")
+ config_file = os.path.join(model_dir, "infer_cfg.yml")
+
+ preprocessor = fd.vision.facedet.BlazeFacePreprocessor()
+ postprocessor = fd.vision.facedet.BlazeFacePostprocessor()
+
+ rc.test_option.set_model_path(model_file, params_file, config_file, model_format=ModelFormat.PADDLE)
+ rc.test_option.use_openvino_backend()
+ runtime = fd.Runtime(rc.test_option)
+
+ with open("resources/blazeface_result1.pkl", "rb") as f:
+ expect1 = pickle.load(f)
+
+ im1 = cv2.imread("resources/000000014439.jpg")
+
+ for i in range(3):
+ # test runtime
+ input_tensors, ims_info = preprocessor.run([im1.copy()])
+ output_tensors = runtime.infer({"images": input_tensors[0]})
+ results = postprocessor.run(output_tensors, ims_info)
+ result1 = results[0]
+
+ diff_boxes_1 = np.fabs(
+ np.array(result1.boxes) - np.array(expect1["boxes"]))
+ diff_scores_1 = np.fabs(
+ np.array(result1.scores) - np.array(expect1["scores"]))
+
+ assert diff_boxes_1.max(
+ ) < 1e-03, "There's difference in detection boxes 1."
+ assert diff_scores_1.max(
+ ) < 1e-04, "There's difference in detection score 1."
+
+
+if __name__ == "__main__":
+ test_detection_blazeface()
+ test_detection_blaze_runtime()