tensorrt-yolov9 (wang-xinyu#1449)
* tensorrt-yolov9

* format change

* change format

* Update block.cpp

* format: add space

* update block.cpp: add space

* update config.h

---------

Co-authored-by: Wang Xinyu <[email protected]>
WuxinrongY and wang-xinyu authored Mar 11, 2024
1 parent e585b15 commit e73bffc
Showing 26 changed files with 4,581 additions and 0 deletions.
56 changes: 56 additions & 0 deletions yolov9/CMakeLists.txt
@@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.10)

project(TRTCreater)

add_definitions(-w)
add_definitions(-std=c++11)
add_definitions(-DAPI_EXPORTS)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)
set(CMAKE_CUDA_ARCHITECTURES 75 86 89)

MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}")
IF (CMAKE_SYSTEM_NAME MATCHES "Linux")
MESSAGE(STATUS "current platform: Linux ")
set(CUDA_COMPILER_PATH "/usr/local/cuda/bin/nvcc")
set(TENSORRT_PATH "/home/benol/Package/TensorRT-8.6.1.6")
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
link_directories(/usr/local/cuda/lib)
ELSEIF (CMAKE_SYSTEM_NAME MATCHES "Windows")
MESSAGE(STATUS "current platform: Windows")
set(CUDA_COMPILER_PATH "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.2/bin/nvcc.exe")
set(TENSORRT_PATH "D:\\Program Files\\TensorRT-8.6.1.6")
set(OpenCV_DIR "D:\\Program Files\\opencv\\build")
include_directories(${PROJECT_SOURCE_DIR}/windows)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
link_directories(${CUDA_LIBRARIES})
ELSE ()  # fallback for other platforms, e.g. aarch64 (any condition text after ELSE is ignored by CMake)
MESSAGE(STATUS "other platform: ${CMAKE_SYSTEM_PROCESSOR}")
include_directories(/usr/local/cuda/targets/aarch64-linux/include)
link_directories(/usr/local/cuda/targets/aarch64-linux/lib)
ENDIF (CMAKE_SYSTEM_NAME MATCHES "Linux")
set(CMAKE_CUDA_COMPILER ${CUDA_COMPILER_PATH})
enable_language(CUDA)

# tensorrt
include_directories(${TENSORRT_PATH}/include)
link_directories(${TENSORRT_PATH}/lib)

find_package(OpenCV)
include_directories(${OpenCV_INCLUDE_DIRS})

include_directories(${PROJECT_SOURCE_DIR}/include/)
include_directories(${PROJECT_SOURCE_DIR}/plugin/)

file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)
file(GLOB_RECURSE PLUGIN_SRCS ${PROJECT_SOURCE_DIR}/plugin/*.cu)

add_library(myplugins SHARED ${PLUGIN_SRCS})
target_link_libraries(myplugins nvinfer cudart)

add_executable(yolov9 demo.cpp ${SRCS})
target_link_libraries(yolov9 nvinfer cudart myplugins ${OpenCV_LIBS})

106 changes: 106 additions & 0 deletions yolov9/README.md
@@ -0,0 +1,106 @@
# yolov9

The Pytorch implementation is [WongKinYiu/yolov9](https://github.com/WongKinYiu/yolov9).

## Contributors


## Progress
- [x] YOLOv9-c:
- [x] FP32
- [x] FP16
- [x] INT8
- [x] YOLOv9-e:
- [x] FP32
- [x] FP16
- [x] INT8

## Requirements

- TensorRT 8.0+
- OpenCV 3.4.0+

## Speed Test

The speed test was done on a desktop with an AMD R7-5700G CPU and an RTX 4060 Ti GPU, with a 640x640 input. FP32, FP16 and INT8 models were tested. The reported time covers inference only, excluding pre-processing and post-processing, and each figure is the average over 1000 inference runs.

| Framework | Model | FP32 | FP16 | INT8 |
| --- | --- | --- | --- | --- |
| pytorch | YOLOv9-c | - | 15.5ms | - |
| pytorch | YOLOv9-e | - | 19.7ms | - |
| tensorrt | YOLOv9-c | 13.5ms | 4.6ms | 3.0ms |
| tensorrt | YOLOv9-e | 8.3ms | 3.2ms | 2.15ms |

In TensorRT, YOLOv9-e is faster than YOLOv9-c because it has to evaluate fewer layers at inference time:
```
YOLOv9-c:
[[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]] # [A3, A4, A5, P3, P4, P5]
YOLOv9-e:
[[35, 32, 29, 42, 45, 48], 1, DualDDetect, [nc]]
```

In DualDDetect, A3, A4, A5, P3, P4 and P5 are outputs of the backbone; only the first three are used to produce the final result. As a consequence, YOLOv9-c has to evaluate layers up to index 37, while YOLOv9-e only needs layers up to index 35.

## How to Run, yolov9-c as example

1. Generate .wts from the PyTorch .pt weights, or download .wts from the model zoo

```
# download https://github.com/WongKinYiu/yolov9
cp {tensorrtx}/yolov9/gen_wts.py {yolov9}/yolov9
cd {yolov9}/yolov9
python gen_wts.py
# a file 'yolov9.wts' will be generated
```
2. Build tensorrtx/yolov9 and run


```
cd {tensorrtx}/yolov9/
# update kNumClass in config.h if your model is trained on a custom dataset
mkdir build
cd build
cp {yolov9}/yolov9/yolov9.wts {tensorrtx}/yolov9/build
cmake ..
make
sudo ./yolov9 -s [.wts] [.engine] [c/e]    # serialize model to plan file
sudo ./yolov9 -d [.engine] [image folder]  # deserialize and run inference on the images in [image folder]
# for example, yolov9-c
sudo ./yolov9 -s yolov9-c.wts yolov9-c.engine c
sudo ./yolov9 -d yolov9-c.engine ../images
```

3. Check the generated images, e.g. _zidane.jpg and _bus.jpg

4. Optionally, load and run the TensorRT model in Python:

```
# install python-tensorrt, pycuda, etc.
# ensure the yolov9.engine and libmyplugins.so have been built
python yolov9_trt.py
```


## INT8 Quantization

1. Prepare calibration images. You can randomly select about 1000 images from your training set. For COCO, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh

2. Unzip it in yolov9/build

3. Set the macro `USE_INT8` in config.h and update the calibration image path in config.h, e.g. `gCalibTablePath = "./coco_calib/";`

4. Serialize the model and test (a sketch of the INT8 wiring is shown below)
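
For orientation, here is a minimal sketch of how INT8 calibration is typically wired into a TensorRT 8 build. It is illustrative only, not the exact code from this PR: the `Int8EntropyCalibrator2` constructor arguments are assumed from the tensorrtx calibrator and may differ from what ships in this commit.

```
// Hedged sketch: Int8EntropyCalibrator2 (tensorrtx's calibrator) implements
// nvinfer1::IInt8EntropyCalibrator2; the constructor arguments below are assumptions.
#include <NvInfer.h>
#include "calibrator.h"
#include "config.h"

void enable_int8(nvinfer1::IBuilderConfig* config) {
    // Request INT8 kernels; layers without an INT8 implementation
    // fall back to higher precision automatically.
    config->setFlag(nvinfer1::BuilderFlag::kINT8);

    // The calibrator streams preprocessed calibration images batch by batch and
    // caches the computed scale table so later builds can skip calibration.
    auto* calibrator = new Int8EntropyCalibrator2(
            kBatchSize, kInputW, kInputH,  // batch size and input geometry
            "./coco_calib/",               // calibration image folder (steps 1-2 above)
            "int8calib.table",             // on-disk cache of the scale table
            kInputTensorName);             // input binding name
    config->setInt8Calibrator(calibrator);
}
```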

<p align="center">
<img src="https://user-images.githubusercontent.com/15235574/78247927-4d9fac00-751e-11ea-8b1b-704a0aeb3fcf.jpg" height="360px;">
</p>

## More Information

See the readme in the [home page](https://github.com/wang-xinyu/tensorrtx).


211 changes: 211 additions & 0 deletions yolov9/demo.cpp
@@ -0,0 +1,211 @@
#include "config.h"
#include "model.h"
#include "cuda_utils.h"
#include "logging.h"
#include "utils.h"
#include "preprocess.h"
#include "postprocess.h"
#include <chrono>
#include <fstream>

using namespace nvinfer1;

const static int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1;
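// Layout assumed by this size: one leading float holding the number of valid
// detections, followed by kMaxNumOutputBbox fixed-size Detection records (the "+ 1").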
static Logger gLogger;

void serialize_engine(unsigned int maxBatchSize, std::string& wts_name, std::string& sub_type, std::string& engine_name) {
// Create builder
IBuilder* builder = createInferBuilder(gLogger);
IBuilderConfig* config = builder->createBuilderConfig();

// Create model to populate the network, then set the outputs and create an engine
IHostMemory* serialized_engine = nullptr;
if (sub_type == "e") {
serialized_engine = build_engine_yolov9_e(maxBatchSize, builder, config, DataType::kFLOAT, wts_name);
} else if(sub_type == "c"){
serialized_engine = build_engine_yolov9_c(maxBatchSize, builder, config, DataType::kFLOAT, wts_name);
}
else {
return;
}
assert(serialized_engine != nullptr);

std::ofstream p(engine_name, std::ios::binary);
if (!p) {
std::cerr << "could not open plan output file" << std::endl;
assert(false);
}
p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());

delete config;
delete serialized_engine;
delete builder;
}

void deserialize_engine(std::string& engine_name, IRuntime** runtime, ICudaEngine** engine, IExecutionContext** context) {
std::ifstream file(engine_name, std::ios::binary);
if (!file.good()) {
std::cerr << "read " << engine_name << " error!" << std::endl;
assert(false);
}
size_t size = 0;
file.seekg(0, file.end);
size = file.tellg();
file.seekg(0, file.beg);
char* serialized_engine = new char[size];
assert(serialized_engine);
file.read(serialized_engine, size);
file.close();

*runtime = createInferRuntime(gLogger);
assert(*runtime);
*engine = (*runtime)->deserializeCudaEngine(serialized_engine, size);
assert(*engine);
*context = (*engine)->createExecutionContext();
assert(*context);
delete[] serialized_engine;
}

void prepare_buffer(ICudaEngine* engine, float** input_buffer_device, float** output_buffer_device, float** output_buffer_host) {
assert(engine->getNbBindings() == 2);
// In order to bind the buffers, we need to know the names of the input and output tensors.
// Note that indices are guaranteed to be less than IEngine::getNbBindings()
const int inputIndex = engine->getBindingIndex(kInputTensorName);
const int outputIndex = engine->getBindingIndex(kOutputTensorName);
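// getNbBindings()/getBindingIndex() are TensorRT's index-based binding APIs,
// deprecated since TensorRT 8.5 in favor of tensor-name based I/O.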
assert(inputIndex == 0);
assert(outputIndex == 1);
// Create GPU buffers on device
CUDA_CHECK(cudaMalloc((void**)input_buffer_device, kBatchSize * 3 * kInputH * kInputW * sizeof(float)));
CUDA_CHECK(cudaMalloc((void**)output_buffer_device, kBatchSize * kOutputSize * sizeof(float)));

*output_buffer_host = new float[kBatchSize * kOutputSize];
}

void infer(IExecutionContext& context, cudaStream_t& stream, void** buffers, float* output, int batchSize) {
// infer on the batch asynchronously, and DMA output back to host
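// Note: enqueue() is the implicit-batch API, deprecated in TensorRT 8.x;
// enqueueV2() is its explicit-batch replacement.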
context.enqueue(batchSize, buffers, stream, nullptr);
CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
}

bool parse_args(int argc, char** argv, std::string& wts, std::string& engine, std::string& img_dir, std::string& sub_type) {
if (argc < 4) return false;
if (std::string(argv[1]) == "-s" && argc == 5) {
wts = std::string(argv[2]);
engine = std::string(argv[3]);
sub_type = std::string(argv[4]);
} else if (std::string(argv[1]) == "-d" && argc == 4) {
engine = std::string(argv[2]);
img_dir = std::string(argv[3]);
} else {
return false;
}
return true;
}
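// Usage (see README):
//   ./yolov9 -s yolov9-c.wts yolov9-c.engine c    // build engine from .wts weights
//   ./yolov9 -d yolov9-c.engine ../images         // run inference on an image folder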

int main(int argc, char** argv) {
cudaSetDevice(kGpuId);

std::string wts_name = "";
std::string engine_name = "";
std::string img_dir;
std::string sub_type = "";
// speed test or inference
// const int speed_test_iter = 1000;
const int speed_test_iter = 1;
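// The README speed test averages 1000 runs; restore speed_test_iter = 1000 to reproduce it.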

if (!parse_args(argc, argv, wts_name, engine_name, img_dir, sub_type)) {
std::cerr << "Arguments not right!" << std::endl;
std::cerr << "./yolov9 -s [.wts] [.engine] [c/e] // serialize model to plan file" << std::endl;
std::cerr << "./yolov9 -d [.engine] ../samples // deserialize plan file and run inference" << std::endl;
return -1;
}


// Create a model using the API directly and serialize it to a file
if (!wts_name.empty()) {
serialize_engine(kBatchSize, wts_name, sub_type, engine_name);
return 0;
}

// Deserialize the engine from file
IRuntime* runtime = nullptr;
ICudaEngine* engine = nullptr;
IExecutionContext* context = nullptr;
deserialize_engine(engine_name, &runtime, &engine, &context);
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));

cuda_preprocess_init(kMaxInputImageSize);

// Prepare cpu and gpu buffers
float* device_buffers[2];
float* output_buffer_host = nullptr;
prepare_buffer(engine, &device_buffers[0], &device_buffers[1], &output_buffer_host);

// Read images from directory
std::vector<std::string> file_names;
if (read_files_in_dir(img_dir.c_str(), file_names) < 0) {
std::cerr << "read_files_in_dir failed." << std::endl;
return -1;
}

// batch predict
for (size_t i = 0; i < file_names.size(); i += kBatchSize) {
// Get a batch of images
std::vector<cv::Mat> img_batch;
std::vector<std::string> img_name_batch;
for (size_t j = i; j < i + kBatchSize && j < file_names.size(); j++) {
cv::Mat img = cv::imread(img_dir + "/" + file_names[j]);
img_batch.push_back(img);
img_name_batch.push_back(file_names[j]);
}

// Preprocess
cuda_batch_preprocess(img_batch, device_buffers[0], kInputW, kInputH, stream);

// Run inference
auto start = std::chrono::system_clock::now();
for (int j = 0; j < speed_test_iter; j++) {
infer(*context, stream, (void**)device_buffers, output_buffer_host, kBatchSize);
}
// infer(*context, stream, (void**)device_buffers, output_buffer_host, kBatchSize);
auto end = std::chrono::system_clock::now();
std::cout << "inference time: " << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.0 / speed_test_iter << "ms" << std::endl;

// NMS
std::vector<std::vector<Detection>> res_batch;
batch_nms(res_batch, output_buffer_host, img_batch.size(), kOutputSize, kConfThresh, kNmsThresh);
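// batch_nms decodes each image's raw [count, Detection...] buffer, then applies the
// confidence threshold and IoU-based NMS (kConfThresh / kNmsThresh from config.h).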

// Draw bounding boxes
draw_bbox(img_batch, res_batch);

// Save images
for (size_t j = 0; j < img_batch.size(); j++) {
cv::imwrite("_" + img_name_batch[j], img_batch[j]);
}
}

// Release stream and buffers
cudaStreamDestroy(stream);
CUDA_CHECK(cudaFree(device_buffers[0]));
CUDA_CHECK(cudaFree(device_buffers[1]));
delete[] output_buffer_host;
cuda_preprocess_destroy();
// Destroy the engine
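// (TensorRT 8.x releases these objects with plain delete; destroy() is deprecated.)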
delete context;
delete engine;
delete runtime;


return 0;
}
