Skip to content

Commit

Permalink
[Model] Add YOLOv5-seg (PaddlePaddle#988)
Browse files Browse the repository at this point in the history
* add onnx_ort_runtime demo

* rm in requirements

* support batch eval

* fixed MattingResults bug

* move assignment for DetectionResult

* integrated x2paddle

* add model convert readme

* update readme

* re-lint

* add processor api

* Add MattingResult Free

* change valid_cpu_backends order

* add ppocr benchmark

* mv bs from 64 to 32

* fixed quantize.md

* fixed quantize bugs

* Add Monitor for benchmark

* update mem monitor

* Set trt_max_batch_size default 1

* fixed ocr benchmark bug

* support yolov5 in serving

* Fixed yolov5 serving

* Fixed postprocess

* update yolov5 to 7.0

* add poros runtime demos

* update readme

* Support poros abi=1

* rm useless note

* deal with comments

* support pp_trt for ppseg

* fixed symlink problem

* Add is_mini_pad and stride for yolov5

* Add yolo series for paddle format

* fixed bugs

* fixed bug

* support yolov5seg

* fixed bug

* refactor yolov5seg

* fixed bug

* mv Mask int32 to uint8

* add yolov5seg example

* rm log info

* fixed code style

* add yolov5seg example in python

* fixed dtype bug

* update note

* deal with comments

* get sorted index

* add yolov5seg test case

* Add GPL-3.0 License

* add round func

* deal with comments

* deal with commens

Co-authored-by: Jason <[email protected]>
  • Loading branch information
wjj19950828 and jiangjiajun authored Jan 11, 2023
1 parent 60e6a12 commit aa6931b
Show file tree
Hide file tree
Showing 28 changed files with 1,607 additions and 33 deletions.
Empty file modified cmake/paddle_inference.cmake
100644 → 100755
Empty file.
27 changes: 27 additions & 0 deletions examples/vision/detection/yolov5seg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# YOLOv5Seg准备部署模型

- YOLOv5Seg v7.0部署模型实现来自[YOLOv5](https://github.com/ultralytics/yolov5/tree/v7.0),和[基于COCO的预训练模型](https://github.com/ultralytics/yolov5/releases/tag/v7.0)
- (1)[官方库](https://github.com/ultralytics/yolov5/releases/tag/v7.0)提供的*.onnx可直接进行部署;
- (2)开发者基于自己数据训练的YOLOv5Seg v7.0模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后,完成部署。


## 下载预训练ONNX模型

为了方便开发者的测试,下面提供了YOLOv5Seg导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库)
| 模型 | 大小 | 精度 | 备注 |
|:---------------------------------------------------------------- |:----- |:----- |:----- |
| [YOLOv5n-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-seg.onnx) | 7.7MB | 27.6% | 此模型文件来源于[YOLOv5](https://github.com/ultralytics/yolov5),GPL-3.0 License |
| [YOLOv5s-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx) | 30MB | 37.6% | 此模型文件来源于[YOLOv5](https://github.com/ultralytics/yolov5),GPL-3.0 License |
| [YOLOv5m-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5m-seg.onnx) | 84MB | 45.0% | 此模型文件来源于[YOLOv5](https://github.com/ultralytics/yolov5),GPL-3.0 License |
| [YOLOv5l-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5l-seg.onnx) | 183MB | 49.0% | 此模型文件来源于[YOLOv5](https://github.com/ultralytics/yolov5),GPL-3.0 License |
| [YOLOv5x-seg](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5x-seg.onnx) | 339MB | 50.7% | 此模型文件来源于[YOLOv5](https://github.com/ultralytics/yolov5),GPL-3.0 License |


## 详细部署文档

- [Python部署](python)
- [C++部署](cpp)

## 版本说明

- 本版本文档和代码基于[YOLOv5 v7.0](https://github.com/ultralytics/yolov5/tree/v7.0) 编写
14 changes: 14 additions & 0 deletions examples/vision/detection/yolov5seg/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)

# Specify the fastdeploy library path after downloading and decompression
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")

include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

# Add FastDeploy dependent header files
include_directories(${FASTDEPLOY_INCS})

add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Add FastDeploy library dependencies
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
74 changes: 74 additions & 0 deletions examples/vision/detection/yolov5seg/cpp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# YOLOv5Seg C++部署示例

本目录下提供`infer.cc`快速完成YOLOv5Seg在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。

在部署前,需确认以下两个步骤

- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试,支持此模型需保证FastDeploy版本1.0.3以上(x.x.x>=1.0.3)

```bash
mkdir build
cd build
# 下载 FastDeploy 预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j

# 1. 下载官方转换好的 YOLOv5Seg ONNX 模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg

# CPU推理
./infer_demo yolov5s-seg.onnx 000000014439.jpg 0
# GPU推理
./infer_demo yolov5s-seg.onnx 000000014439.jpg 1
# GPU上TensorRT推理
./infer_demo yolov5s-seg.onnx 000000014439.jpg 2
```
运行完成可视化结果如下图所示

<img width="640" src="https://user-images.githubusercontent.com/19977378/209955620-657bdd1d-574c-40a2-b05d-42b9e5a15ae8.png">

以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)

## YOLOv5Seg C++接口

### YOLOv5Seg类

```c++
fastdeploy::vision::detection::YOLOv5Seg(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX)
```
YOLOv5Seg模型加载和初始化,其中model_file为导出的ONNX模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
> * **model_format**(ModelFormat): 模型格式,默认为ONNX格式
#### Predict函数
```c++
YOLOv5Seg::Predict(const cv::Mat& img, DetectionResult* result)
```

**参数**

> > * **im**: 输入图像,注意需为HWC,BGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)
105 changes: 105 additions & 0 deletions examples/vision/detection/yolov5seg/cpp/infer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

void CpuInfer(const std::string& model_file, const std::string& image_file) {
auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im = fastdeploy::vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}

void GpuInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im = fastdeploy::vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}

void TrtInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("images", {1, 3, 640, 640});
auto model = fastdeploy::vision::detection::YOLOv5Seg(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im = fastdeploy::vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}

int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_model ./yolov5.onnx ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}

if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 2) {
TrtInfer(argv[1], argv[2]);
}
return 0;
}
67 changes: 67 additions & 0 deletions examples/vision/detection/yolov5seg/python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# YOLOv5Seg Python部署示例

在部署前,需确认以下两个步骤

- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

本目录下提供`infer.py`快速完成YOLOv5Seg在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成

```bash
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vision/detection/yolov5seg/python/

#下载yolov5seg模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-seg.onnx
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg

# CPU推理
python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device cpu
# GPU推理
python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device gpu
# GPU上使用TensorRT推理
python infer.py --model yolov5s-seg.onnx --image 000000014439.jpg --device gpu --use_trt True
```

运行完成可视化结果如下图所示

<img width="640" src="https://user-images.githubusercontent.com/19977378/209955620-657bdd1d-574c-40a2-b05d-42b9e5a15ae8.png">

## YOLOv5Seg Python接口

```python
fastdeploy.vision.detection.YOLOv5Seg(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
```

YOLOv5Seg模型加载和初始化,其中model_file为导出的ONNX模型格式

**参数**

> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
> * **model_format**(ModelFormat): 模型格式,默认为ONNX
### predict函数

```python
YOLOv5Seg.predict(image_data)
```

模型预测结口,输入图像直接输出检测结果。

**参数**

> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式
**返回**

> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## 其它文档

- [YOLOv5Seg 模型介绍](..)
- [YOLOv5Seg C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)
56 changes: 56 additions & 0 deletions examples/vision/detection/yolov5seg/python/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import fastdeploy as fd
import cv2
import os


def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", default=None, help="Path of yolov5seg model.")
parser.add_argument(
"--image", default=None, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()


def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()

if args.use_trt:
option.use_trt_backend()
option.set_trt_input_shape("images", [1, 3, 640, 640])
return option


args = parse_arguments()

# Configure runtime, load model
runtime_option = build_option(args)
model = fd.vision.detection.YOLOv5Seg(
args.model, runtime_option=runtime_option)

# Predicting image
if args.image is None:
image = fd.utils.get_detection_test_image()
else:
image = args.image
im = cv2.imread(image)
result = model.predict(im)

# Visualization
vis_im = fd.vision.vis_detection(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")
Empty file modified fastdeploy/runtime/backends/paddle/paddle_backend.h
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions fastdeploy/vision.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"
#include "fastdeploy/vision/detection/contrib/yolor.h"
#include "fastdeploy/vision/detection/contrib/yolov5/yolov5.h"
#include "fastdeploy/vision/detection/contrib/yolov5seg/yolov5seg.h"
#include "fastdeploy/vision/detection/contrib/fastestdet/fastestdet.h"
#include "fastdeploy/vision/detection/contrib/yolov5lite.h"
#include "fastdeploy/vision/detection/contrib/yolov6.h"
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/vision/common/result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void Mask::Reserve(int size) { data.reserve(size); }
void Mask::Resize(int size) { data.resize(size); }

void Mask::Clear() {
std::vector<int32_t>().swap(data);
std::vector<uint8_t>().swap(data);
std::vector<int64_t>().swap(shape);
}

Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/vision/common/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct FASTDEPLOY_DECL ClassifyResult : public BaseResult {
*/
struct FASTDEPLOY_DECL Mask : public BaseResult {
/// Mask data buffer
std::vector<int32_t> data;
std::vector<uint8_t> data;
/// Shape of mask
std::vector<int64_t> shape; // (H,W) ...
ResultType type = ResultType::MASK;
Expand Down Expand Up @@ -107,7 +107,7 @@ struct FASTDEPLOY_DECL DetectionResult : public BaseResult {
/** \brief For instance segmentation model, `masks` is the predict mask for all the deteced objects
*/
std::vector<Mask> masks;
//// Shows if the DetectionResult has mask
/// Shows if the DetectionResult has mask
bool contain_masks = false;

ResultType type = ResultType::DETECTION;
Expand Down
Loading

0 comments on commit aa6931b

Please sign in to comment.