diff --git a/yolov8/include/model.h b/yolov8/include/model.h index 5e153903..3e7bcbe0 100644 --- a/yolov8/include/model.h +++ b/yolov8/include/model.h @@ -3,17 +3,17 @@ #include #include -nvinfer1::IHostMemory* buildEngineYolov8n(const int& batchsize, nvinfer1::IBuilder* builder, +nvinfer1::IHostMemory* buildEngineYolov8n(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path); -nvinfer1::IHostMemory* buildEngineYolov8s(const int& batchsize, nvinfer1::IBuilder* builder, +nvinfer1::IHostMemory* buildEngineYolov8s(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path); -nvinfer1::IHostMemory* buildEngineYolov8m(const int& batchsize, nvinfer1::IBuilder* builder, +nvinfer1::IHostMemory* buildEngineYolov8m(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path); -nvinfer1::IHostMemory* buildEngineYolov8l(const int& batchsize, nvinfer1::IBuilder* builder, +nvinfer1::IHostMemory* buildEngineYolov8l(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path); -nvinfer1::IHostMemory* buildEngineYolov8x(const int& batchsize, nvinfer1::IBuilder* builder, +nvinfer1::IHostMemory* buildEngineYolov8x(nvinfer1::IBuilder* builder, nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path); diff --git a/yolov8/include/postprocess.h b/yolov8/include/postprocess.h index 66522afc..95da5648 100644 --- a/yolov8/include/postprocess.h +++ b/yolov8/include/postprocess.h @@ -1,6 +1,7 @@ #pragma once #include "types.h" +#include "NvInfer.h" #include cv::Rect get_rect(cv::Mat& img, float bbox[4]); @@ -9,6 +10,13 @@ void nms(std::vector& res, float *output, float conf_thresh, float nm void batch_nms(std::vector>& batch_res, float *output, int batch_size, int output_size, float conf_thresh, float nms_thresh = 0.5); -void draw_bbox(std::vector& img_batch, std::vector>& res_batch); +void draw_bbox(std::vector &img_batch, std::vector> &res_batch); + +void batch_process(std::vector> &res_batch, const float* decode_ptr_host, int batch_size, int bbox_element, const std::vector& img_batch); + +void process_decode_ptr_host(std::vector &res, const float* decode_ptr_host, int bbox_element, cv::Mat& img, int count); + +void cuda_decode(float* predict, int num_bboxes, float confidence_threshold,float* parray,int max_objects, cudaStream_t stream); + +void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream); -void batch_process(std::vector>& res_batch, const float* decode_ptr_host, int batch_size, int bbox_element, const std::vector& img_batch); diff --git a/yolov8/include/preprocess.h b/yolov8/include/preprocess.h index e1aea64b..10bead97 100644 --- a/yolov8/include/preprocess.h +++ b/yolov8/include/preprocess.h @@ -6,12 +6,6 @@ #include -struct AffineMatrix { - float value[6]; -}; - -const int bbox_element = sizeof(AffineMatrix) / sizeof(float)+1; // left, top, right, bottom, confidence, class, keepflag - void cuda_preprocess_init(int max_image_size); void cuda_preprocess_destroy(); @@ -20,6 +14,3 @@ void cuda_preprocess(uint8_t *src, int src_width, int src_height, float *dst, in void cuda_batch_preprocess(std::vector &img_batch, float *dst, int dst_width, int dst_height, cudaStream_t stream); -void cuda_decode(float* predict, int num_bboxes, float confidence_threshold,float* parray,int max_objects, cudaStream_t stream); - -void cuda_nms(float* parray, float nms_threshold, int max_objects, cudaStream_t stream); \ No newline at end of file diff --git a/yolov8/include/types.h b/yolov8/include/types.h index 949ef598..574b913b 100644 --- a/yolov8/include/types.h +++ b/yolov8/include/types.h @@ -8,3 +8,8 @@ struct alignas(float) Detection { float class_id; }; +struct AffineMatrix { + float value[6]; +}; + +const int bbox_element = sizeof(AffineMatrix) / sizeof(float)+1; // left, top, right, bottom, confidence, class, keepflag diff --git a/yolov8/main.cpp b/yolov8/main.cpp index c134e0d7..74820db6 100644 --- a/yolov8/main.cpp +++ b/yolov8/main.cpp @@ -1,32 +1,33 @@ + +#include +#include +#include #include "model.h" #include "utils.h" #include "preprocess.h" #include "postprocess.h" -#include -#include #include "cuda_utils.h" -#include #include "logging.h" Logger gLogger; using namespace nvinfer1; const int kOutputSize = kMaxNumOutputBbox * sizeof(Detection) / sizeof(float) + 1; -void serialize_engine(const int &kBatchSize, std::string &wts_name, std::string &engine_name, std::string &sub_type) { +void serialize_engine(std::string &wts_name, std::string &engine_name, std::string &sub_type) { IBuilder *builder = createInferBuilder(gLogger); IBuilderConfig *config = builder->createBuilderConfig(); IHostMemory *serialized_engine = nullptr; if (sub_type == "n") { - serialized_engine = buildEngineYolov8n(kBatchSize, builder, config, DataType::kFLOAT, wts_name); + serialized_engine = buildEngineYolov8n(builder, config, DataType::kFLOAT, wts_name); } else if (sub_type == "s") { - serialized_engine = buildEngineYolov8s(kBatchSize, builder, config, DataType::kFLOAT, wts_name); + serialized_engine = buildEngineYolov8s(builder, config, DataType::kFLOAT, wts_name); } else if (sub_type == "m") { - serialized_engine = buildEngineYolov8m(kBatchSize, builder, config, DataType::kFLOAT, wts_name); + serialized_engine = buildEngineYolov8m(builder, config, DataType::kFLOAT, wts_name); } else if (sub_type == "l") { - serialized_engine = buildEngineYolov8l(kBatchSize, builder, config, DataType::kFLOAT, wts_name); + serialized_engine = buildEngineYolov8l(builder, config, DataType::kFLOAT, wts_name); } else if (sub_type == "x") { - serialized_engine = buildEngineYolov8x(kBatchSize, builder, config, DataType::kFLOAT, wts_name); + serialized_engine = buildEngineYolov8x(builder, config, DataType::kFLOAT, wts_name); } assert(serialized_engine); @@ -88,12 +89,12 @@ void prepare_buffer(ICudaEngine *engine, float **input_buffer_device, float **ou } } -void infer(IExecutionContext &context, cudaStream_t &stream, void **buffers, float *output, int batchSize, float* decode_ptr_host, float* decode_ptr_device, int batchSize_in, int model_bboxes, std::string cuda_post_process) { +void infer(IExecutionContext &context, cudaStream_t &stream, void **buffers, float *output, int batchsize, float* decode_ptr_host, float* decode_ptr_device, int model_bboxes, std::string cuda_post_process) { // infer on the batch asynchronously, and DMA output back to host auto start = std::chrono::system_clock::now(); - context.enqueue(batchSize, buffers, stream, nullptr); + context.enqueue(batchsize, buffers, stream, nullptr); if (cuda_post_process == "c") { - CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost,stream)); + CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchsize * kOutputSize * sizeof(float), cudaMemcpyDeviceToHost,stream)); auto end = std::chrono::system_clock::now(); std::cout << "inference time: " << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; } else if (cuda_post_process == "g") { @@ -143,7 +144,7 @@ int main(int argc, char **argv) { // Create a model using the API directly and serialize it to a file if (!wts_name.empty()) { - serialize_engine(kBatchSize, wts_name, engine_name, sub_type); + serialize_engine(wts_name, engine_name, sub_type); return 0; } @@ -185,7 +186,7 @@ int main(int argc, char **argv) { // Preprocess cuda_batch_preprocess(img_batch, device_buffers[0], kInputW, kInputH, stream); // Run inference - infer(*context, stream, (void **)device_buffers, output_buffer_host, kBatchSize, decode_ptr_host, decode_ptr_device, img_batch.size(), model_bboxes, cuda_post_process); + infer(*context, stream, (void **)device_buffers, output_buffer_host, kBatchSize, decode_ptr_host, decode_ptr_device, model_bboxes, cuda_post_process); std::vector> res_batch; if (cuda_post_process == "c") { // NMS diff --git a/yolov8/src/model.cpp b/yolov8/src/model.cpp index 87d209de..54b20637 100644 --- a/yolov8/src/model.cpp +++ b/yolov8/src/model.cpp @@ -3,167 +3,160 @@ #include "calibrator.h" #include #include "config.h" -using namespace nvinfer1; - - -IHostMemory* buildEngineYolov8n(const int& batchsize, IBuilder* builder, - IBuilderConfig* config, DataType dt, const std::string& wts_path){ - std::map weightMap = loadWeights(wts_path); - INetworkDefinition* network = builder->createNetworkV2(0U); +nvinfer1::IHostMemory* buildEngineYolov8n(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path) { + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); /******************************************************************************************************* ****************************************** YOLOV8 INPUT ********************************************** *******************************************************************************************************/ - ITensor* data = network->addInput(kInputTensorName, dt, Dims3{3, kInputH, kInputW}); + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{3, kInputH, kInputW}); assert(data); /******************************************************************************************************* ***************************************** YOLOV8 BACKBONE ******************************************** *******************************************************************************************************/ - IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 16, 3, 2, 1, "model.0"); - IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 32, 3, 2, 1, "model.1"); - IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 32, 32, 1, true, 0.5, "model.2"); - IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 64, 3, 2, 1, "model.3"); - IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 64, 64, 2, true, 0.5, "model.4"); - IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 128, 3, 2, 1, "model.5"); - IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 128, 128, 2, true, 0.5, "model.6"); - IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 256, 3, 2, 1, "model.7"); - IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 256, 256, 1, true, 0.5, "model.8"); - IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 256, 256, 5, "model.9"); + nvinfer1::IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 16, 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 32, 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 32, 32, 1, true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 64, 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 64, 64, 2, true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 128, 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 128, 128, 2, true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 256, 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 256, 256, 1, true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 256, 256, 5, "model.9"); /******************************************************************************************************* ********************************************* YOLOV8 HEAD ******************************************** *******************************************************************************************************/ float scale[] = {1.0, 2.0, 2.0}; - IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); assert(upsample10); - upsample10->setResizeMode(ResizeMode::kNEAREST); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample10->setScales(scale, 3); - ITensor* inputTensor11[] = {upsample10->getOutput(0), conv6->getOutput(0)}; - IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::ITensor* inputTensor11[] = {upsample10->getOutput(0), conv6->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); - IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 128, 128, 1, false, 0.5, "model.12"); + nvinfer1::IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 128, 128, 1, false, 0.5, "model.12"); - IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); assert(upsample13); - upsample13->setResizeMode(ResizeMode::kNEAREST); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample13->setScales(scale, 3); - ITensor* inputTensor14[] = {upsample13->getOutput(0), conv4->getOutput(0)}; - IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::ITensor* inputTensor14[] = {upsample13->getOutput(0), conv4->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); - IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 64, 64, 1, false, 0.5, "model.15"); - IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 2, 1, "model.16"); - ITensor* inputTensor17[] = {conv16->getOutput(0), conv12->getOutput(0)}; - IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); - IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 128, 128, 1, false, 0.5, "model.18"); - IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 128, 3, 2, 1, "model.19"); - ITensor* inputTensor20[] = {conv19->getOutput(0), conv9->getOutput(0)}; - IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); - IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 256, 256, 1, false, 0.5, "model.21"); + nvinfer1::IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 64, 64, 1, false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = {conv16->getOutput(0), conv12->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 128, 128, 1, false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 128, 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = {conv19->getOutput(0), conv9->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 256, 256, 1, false, 0.5, "model.21"); /******************************************************************************************************* ********************************************* YOLOV8 OUTPUT ****************************************** *******************************************************************************************************/ // output0 - IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); - IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); - IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, DimsHW{1,1}, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); - conv22_cv2_0_2->setStrideNd(DimsHW{1, 1}); - conv22_cv2_0_2->setPaddingNd(DimsHW{0, 0}); - - IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.0.0"); - IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.0.1"); - IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, DimsHW{1,1}, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); - conv22_cv3_0_2->setStride(DimsHW{1, 1}); - conv22_cv3_0_2->setPadding(DimsHW{0, 0}); - ITensor* inputTensor22_0[] = {conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0)}; - IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{1,1}, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{1, 1}); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{0, 0}); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv3.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1,1}, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{1, 1}); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{0, 0}); + nvinfer1::ITensor* inputTensor22_0[] = {conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); // output1 - IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); - IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); - IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, DimsHW{1, 1}, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); - conv22_cv2_1_2->setStrideNd(DimsHW{1,1}); - conv22_cv2_1_2->setPaddingNd(DimsHW{0,0}); + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{1, 1}, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{1,1}); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{0,0}); - IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.1.0"); - IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.1.1"); - IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, DimsHW{1, 1}, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); - conv22_cv3_1_2->setStrideNd(DimsHW{1,1}); - conv22_cv3_1_2->setPaddingNd(DimsHW{0,0}); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1, 1}, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{1,1}); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{0,0}); - ITensor* inputTensor22_1[] = {conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0)}; - IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + nvinfer1::ITensor* inputTensor22_1[] = {conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); // output2 - IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); - IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); - IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, DimsHW{1,1}, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{1,1}, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); - IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.2.0"); - IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv3.2.1"); - IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, DimsHW{1,1}, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{1,1}, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); - ITensor* inputTensor22_2[] = {conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0)}; - IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + nvinfer1::ITensor* inputTensor22_2[] = {conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); /******************************************************************************************************* ********************************************* YOLOV8 DETECT ****************************************** *******************************************************************************************************/ - IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - nvinfer1::Dims shuffle22_0_shape = shuffle22_0->getOutput(0)->getDimensions(); - int first_dim = shuffle22_0_shape.d[0]; - - - shuffle22_0->setReshapeDimensions(Dims2{first_dim, (kInputH / 8) * (kInputW / 8) }); - ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{0, 0}, Dims2{64, (kInputH / 8) * (kInputW / 8) }, Dims2{1,1}); - ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{64, 0}, Dims2{kNumClass, (kInputH / 8) * (kInputW / 8) }, Dims2{1,1}); - IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_0[] = {dfl22_0->getOutput(0), split22_0_1->getOutput(0)}; - IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); - - IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(Dims2{first_dim, (kInputH / 16) * (kInputW / 16) }); - ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{0, 0}, Dims2{64, (kInputH / 16) * (kInputW / 16) }, Dims2{1,1}); - ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{64, 0}, Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, Dims2{1,1}); - IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_1[] = {dfl22_1->getOutput(0), split22_1_1->getOutput(0)}; - IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); - - IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - shuffle22_2->setReshapeDimensions(Dims2{first_dim, (kInputH / 32) * (kInputW / 32) }); - ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{0, 0}, Dims2{64, (kInputH / 32) * (kInputW / 32) }, Dims2{1,1}); - ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{64, 0}, Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, Dims2{1,1}); - IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_2[] = {dfl22_2->getOutput(0), split22_2_1->getOutput(0)}; - IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); - - IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{64 + kNumClass, (kInputH / 8) * (kInputW / 8) }); + + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{1,1}); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{64, 0}, nvinfer1::Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{1,1}); + nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = {dfl22_0->getOutput(0), split22_0_1->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 16) * (kInputW / 16) }); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{1,1}); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{64, 0}, nvinfer1::Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{1,1}); + nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = {dfl22_1->getOutput(0), split22_1_1->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 32) * (kInputW / 32) }); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{0, 0}, nvinfer1::Dims2{64, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{1,1}); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{64, 0}, nvinfer1::Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{1,1}); + nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = {dfl22_2->getOutput(0), split22_2_1->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); + + nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); - builder->setMaxBatchSize(batchsize); + builder->setMaxBatchSize(kBatchSize); config->setMaxWorkspaceSize(16* (1<<20)); #if defined(USE_FP16) - config->setFlag(BuilderFlag::kFP16); + config->setFlag(nvinfer1::BuilderFlag::kFP16); #elif defined(USE_INT8) std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; assert(builder->platformHasFastInt8()); - config->setFlag(BuilderFlag::kINT8); - Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + nvinfer1::IInt8EntropyCalibrator2* calibrator = new Calibrator(1, kInputW, kInputH, "../calibrator/", "int8calib.table", kInputTensorName); config->setInt8Calibrator(calibrator); - - #endif std::cout << "Building engine, please wait for a while..." << std::endl; - IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); std::cout << "Build engine successfully!" << std::endl; delete network; @@ -176,167 +169,157 @@ IHostMemory* buildEngineYolov8n(const int& batchsize, IBuilder* builder, } -IHostMemory* buildEngineYolov8s(const int& batchsize, IBuilder* builder, - IBuilderConfig* config, DataType dt, const std::string& wts_path) { +nvinfer1::IHostMemory* buildEngineYolov8s(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path) { - std::map weightMap = loadWeights(wts_path); - INetworkDefinition* network = builder->createNetworkV2(0U); + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); /******************************************************************************************************* ****************************************** YOLOV8 INPUT ********************************************** *******************************************************************************************************/ - ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW }); + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{ 3, kInputH, kInputW }); assert(data); /******************************************************************************************************* ***************************************** YOLOV8 BACKBONE ******************************************** *******************************************************************************************************/ - IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 32, 3, 2, 1, "model.0"); - IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 64, 3, 2, 1, "model.1"); - IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 64, 64, 1, true, 0.5, "model.2"); - IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 128, 3, 2, 1, "model.3"); - IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 128, 128, 2, true, 0.5, "model.4"); - IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 256, 3, 2, 1, "model.5"); - IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 256, 256, 2, true, 0.5, "model.6"); - IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 512, 3, 2, 1, "model.7"); - IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 512, 512, 1, true, 0.5, "model.8"); - IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 512, 512, 5, "model.9"); + nvinfer1::IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 32, 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 64, 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 64, 64, 1, true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 128, 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 128, 128, 2, true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 256, 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 256, 256, 2, true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 512, 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 512, 512, 1, true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 512, 512, 5, "model.9"); /******************************************************************************************************* ********************************************* YOLOV8 HEAD ******************************************** *******************************************************************************************************/ + float scale[] = { 1.0, 2.0, 2.0 }; - IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); assert(upsample10); - upsample10->setResizeMode(ResizeMode::kNEAREST); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample10->setScales(scale, 3); - ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; - IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); - IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 256, 256, 1, false, 0.5, "model.12"); + nvinfer1::IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 256, 256, 1, false, 0.5, "model.12"); - IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); assert(upsample13); - upsample13->setResizeMode(ResizeMode::kNEAREST); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample13->setScales(scale, 3); - ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; - IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); - IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 128, 128, 1, false, 0.5, "model.15"); - IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 128, 3, 2, 1, "model.16"); - ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; - IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); - IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 256, 256, 1, false, 0.5, "model.18"); - IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 256, 3, 2, 1, "model.19"); - ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; - IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); - IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 512, 512, 1, false, 0.5, "model.21"); + nvinfer1::IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 128, 128, 1, false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 128, 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 256, 256, 1, false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 256, 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 512, 512, 1, false, 0.5, "model.21"); /******************************************************************************************************* ********************************************* YOLOV8 OUTPUT ****************************************** *******************************************************************************************************/ // output0 - IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); - IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); - IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); - conv22_cv2_0_2->setStrideNd(DimsHW{ 1, 1 }); - conv22_cv2_0_2->setPaddingNd(DimsHW{ 0, 0 }); - - IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 128, 3, 1, 1, "model.22.cv3.0.0"); - IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.0.1"); - IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); - conv22_cv3_0_2->setStride(DimsHW{ 1, 1 }); - conv22_cv3_0_2->setPadding(DimsHW{ 0, 0 }); - ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; - IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{ 0, 0 }); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 128, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{ 0, 0 }); + nvinfer1::ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); // output1 - IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); - IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); - IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); - conv22_cv2_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv2_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 128, 3, 1, 1, "model.22.cv3.1.0"); - IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.1.1"); - IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); - conv22_cv3_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv3_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 128, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; - IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + nvinfer1::ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); // output2 - IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); - IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); - IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); - IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 128, 3, 1, 1, "model.22.cv3.2.0"); - IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.2.1"); - IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 128, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 128, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); - ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; - IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + nvinfer1::ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); /******************************************************************************************************* ********************************************* YOLOV8 DETECT ****************************************** *******************************************************************************************************/ - - - IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - nvinfer1::Dims shuffle22_0_shape = shuffle22_0->getOutput(0)->getDimensions(); - int first_dim = shuffle22_0_shape.d[0]; - - - - shuffle22_0->setReshapeDimensions(Dims2{ first_dim, (kInputH / 8) * (kInputW / 8) }); - ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); - - - - IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(Dims2{ first_dim, (kInputH / 16) * (kInputW / 16) }); - ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); - - IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - - - shuffle22_2->setReshapeDimensions(Dims2{ first_dim, (kInputH / 32) * (kInputW / 32) }); - ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); - - - IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 8) * (kInputW / 8) }); + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 16) * (kInputW / 16) }); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 32) * (kInputW / 32) }); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); + + nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); - builder->setMaxBatchSize(batchsize); + builder->setMaxBatchSize(kBatchSize); config->setMaxWorkspaceSize(16 * (1 << 20)); #if defined(USE_FP16) - config->setFlag(BuilderFlag::kFP16); + config->setFlag(nvinfer1::BuilderFlag::kFP16); #elif defined(USE_INT8) std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; assert(builder->platformHasFastInt8()); - config->setFlag(BuilderFlag::kINT8); - Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + nvinfer1::IInt8EntropyCalibrator2* calibrator = new Calibrator(1, kInputW, kInputH, "../calibrator/", "int8calib.table", kInputTensorName); config->setInt8Calibrator(calibrator); #endif std::cout << "Building engine, please wait for a while..." << std::endl; - IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); std::cout << "Build engine successfully!" << std::endl; delete network; @@ -348,153 +331,150 @@ IHostMemory* buildEngineYolov8s(const int& batchsize, IBuilder* builder, } -IHostMemory* buildEngineYolov8m(const int& batchsize, IBuilder* builder, - IBuilderConfig* config, DataType dt, const std::string& wts_path) { - std::map weightMap = loadWeights(wts_path); - INetworkDefinition* network = builder->createNetworkV2(0U); +nvinfer1::IHostMemory* buildEngineYolov8m(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path) { + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); /******************************************************************************************************* ****************************************** YOLOV8 INPUT ********************************************** *******************************************************************************************************/ - ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW }); + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{ 3, kInputH, kInputW }); assert(data); /******************************************************************************************************* ***************************************** YOLOV8 BACKBONE ******************************************** *******************************************************************************************************/ - IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 48, 3, 2, 1, "model.0"); - IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 96, 3, 2, 1, "model.1"); - IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 96, 96, 2, true, 0.5, "model.2"); - IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 192, 3, 2, 1, "model.3"); - IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 192, 192, 4, true, 0.5, "model.4"); - IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 384, 3, 2, 1, "model.5"); - IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 384, 384, 4, true, 0.5, "model.6"); - IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 576, 3, 2, 1, "model.7"); - IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 576, 576, 2, true, 0.5, "model.8"); - IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 576, 576, 5, "model.9"); + nvinfer1::IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 48, 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 96, 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 96, 96, 2, true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 192, 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 192, 192, 4, true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 384, 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 384, 384, 4, true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 576, 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 576, 576, 2, true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 576, 576, 5, "model.9"); /******************************************************************************************************* ********************************************* YOLOV8 HEAD ******************************************** *******************************************************************************************************/ float scale[] = { 1.0, 2.0, 2.0 }; - IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); - upsample10->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample10->setScales(scale, 3); - ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; - IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); - IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 384, 384, 2, false, 0.5, "model.12"); + nvinfer1::ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 384, 384, 2, false, 0.5, "model.12"); - IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); - upsample13->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample13->setScales(scale, 3); - ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; - IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); - IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 192, 192, 2, false, 0.5, "model.15"); - IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 192, 3, 2, 1, "model.16"); - ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; - IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); - IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 384, 384, 2, false, 0.5, "model.18"); - IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 384, 3, 2, 1, "model.19"); - ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; - IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); - IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 576, 576, 2, false, 0.5, "model.21"); + nvinfer1::ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 192, 192, 2, false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 192, 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 384, 384, 2, false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 384, 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 576, 576, 2, false, 0.5, "model.21"); /******************************************************************************************************* ********************************************* YOLOV8 OUTPUT ****************************************** *******************************************************************************************************/ // output0 - IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); - IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); - IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); - conv22_cv2_0_2->setStrideNd(DimsHW{ 1, 1 }); - conv22_cv2_0_2->setPaddingNd(DimsHW{ 0, 0 }); - - IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 192, 3, 1, 1, "model.22.cv3.0.0"); - IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.0.1"); - IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); - conv22_cv3_0_2->setStride(DimsHW{ 1, 1 }); - conv22_cv3_0_2->setPadding(DimsHW{ 0, 0 }); - ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; - IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{ 0, 0 }); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 192, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{ 0, 0 }); + nvinfer1::ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); // output1 - IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); - IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); - IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); - conv22_cv2_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv2_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 192, 3, 1, 1, "model.22.cv3.1.0"); - IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.1.1"); - IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); - conv22_cv3_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv3_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 192, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; - IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + nvinfer1::ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); // output2 - IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); - IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); - IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); - IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 192, 3, 1, 1, "model.22.cv3.2.0"); - IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.2.1"); - IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 192, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 192, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); - ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; - IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + nvinfer1::ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); /******************************************************************************************************* ********************************************* YOLOV8 DETECT ****************************************** *******************************************************************************************************/ - IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - nvinfer1::Dims shuffle22_0_shape = shuffle22_0->getOutput(0)->getDimensions(); - int first_dim = shuffle22_0_shape.d[0]; - - - shuffle22_0->setReshapeDimensions(Dims2{ first_dim, (kInputH / 8) * (kInputW / 8) }); - ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); - - IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(Dims2{ first_dim, (kInputH / 16) * (kInputW / 16) }); - ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); - - IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - shuffle22_2->setReshapeDimensions(Dims2{ first_dim, (kInputH / 32) * (kInputW / 32) }); - ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); - - IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 8) * (kInputW / 8) }); + + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 16) * (kInputW / 16) }); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 32) * (kInputW / 32) }); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); + + nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); - builder->setMaxBatchSize(batchsize); + builder->setMaxBatchSize(kBatchSize); config->setMaxWorkspaceSize(16 * (1 << 20)); #if defined(USE_FP16) - config->setFlag(BuilderFlag::kFP16); + config->setFlag(nvinfer1::BuilderFlag::kFP16); #elif defined(USE_INT8) std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; assert(builder->platformHasFastInt8()); - config->setFlag(BuilderFlag::kINT8); - Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + nvinfer1::IInt8EntropyCalibrator2* calibrator = new Calibrator(1, kInputW, kInputH, "../calibrator/", "int8calib.table", kInputTensorName); config->setInt8Calibrator(calibrator); #endif std::cout << "Building engine, please wait for a while..." << std::endl; - IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); std::cout << "Build engine successfully!" << std::endl; delete network; @@ -506,154 +486,151 @@ IHostMemory* buildEngineYolov8m(const int& batchsize, IBuilder* builder, } -IHostMemory* buildEngineYolov8l(const int& batchsize, IBuilder* builder, - IBuilderConfig* config, DataType dt, const std::string& wts_path) { - std::map weightMap = loadWeights(wts_path); - INetworkDefinition* network = builder->createNetworkV2(0U); +nvinfer1::IHostMemory* buildEngineYolov8l(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path) { + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); /******************************************************************************************************* ****************************************** YOLOV8 INPUT ********************************************** *******************************************************************************************************/ - ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW }); + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{ 3, kInputH, kInputW }); assert(data); /******************************************************************************************************* ***************************************** YOLOV8 BACKBONE ******************************************** *******************************************************************************************************/ - IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 64, 3, 2, 1, "model.0"); - IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 128, 3, 2, 1, "model.1"); - IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 128, 128, 3, true, 0.5, "model.2"); - IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 256, 3, 2, 1, "model.3"); - IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 256, 256, 6, true, 0.5, "model.4"); - IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 512, 3, 2, 1, "model.5"); - IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 512, 512, 6, true, 0.5, "model.6"); - IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 512, 3, 2, 1, "model.7"); - IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 512, 512, 3, true, 0.5, "model.8"); - IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 512, 512, 5, "model.9"); + nvinfer1::IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 64, 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 128, 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 128, 128, 3, true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 256, 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 256, 256, 6, true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 512, 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 512, 512, 6, true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 512, 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 512, 512, 3, true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 512, 512, 5, "model.9"); /******************************************************************************************************* ****************************************** YOLOV8 HEAD *********************************************** *******************************************************************************************************/ float scale[] = { 1.0, 2.0, 2.0 }; - IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); - upsample10->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample10->setScales(scale, 3); - ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; - IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); - IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 512, 512, 3, false, 0.5, "model.12"); + nvinfer1::ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 512, 512, 3, false, 0.5, "model.12"); - IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); - upsample13->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample13->setScales(scale, 3); - ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; - IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); - IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 256, 256, 3, false, 0.5, "model.15"); - IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 256, 3, 2, 1, "model.16"); - ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; - IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); - IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 512, 512, 3, false, 0.5, "model.18"); - IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 512, 3, 2, 1, "model.19"); - ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; - IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); - IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 512, 512, 3, false, 0.5, "model.21"); + nvinfer1::ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 256, 256, 3, false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 256, 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 512, 512, 3, false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 512, 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 512, 512, 3, false, 0.5, "model.21"); /******************************************************************************************************* ********************************************* YOLOV8 OUTPUT ****************************************** *******************************************************************************************************/ // output0 - IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); - IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); - IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); - conv22_cv2_0_2->setStrideNd(DimsHW{ 1, 1 }); - conv22_cv2_0_2->setPaddingNd(DimsHW{ 0, 0 }); - - IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 256, 3, 1, 1, "model.22.cv3.0.0"); - IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.0.1"); - IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); - conv22_cv3_0_2->setStride(DimsHW{ 1, 1 }); - conv22_cv3_0_2->setPadding(DimsHW{ 0, 0 }); - ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; - IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{ 0, 0 }); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 256, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{ 0, 0 }); + nvinfer1::ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); // output1 - IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); - IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); - IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); - conv22_cv2_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv2_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 256, 3, 1, 1, "model.22.cv3.1.0"); - IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.1.1"); - IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); - conv22_cv3_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv3_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 256, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; - IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + nvinfer1::ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); // output2 - IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); - IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); - IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 64, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); - IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 256, 3, 1, 1, "model.22.cv3.2.0"); - IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.2.1"); - IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 256, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 256, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); - ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; - IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + nvinfer1::ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); /******************************************************************************************************* ********************************************* YOLOV8 DETECT ****************************************** *******************************************************************************************************/ - IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - nvinfer1::Dims shuffle22_0_shape = shuffle22_0->getOutput(0)->getDimensions(); - int first_dim = shuffle22_0_shape.d[0]; - - - shuffle22_0->setReshapeDimensions(Dims2{ first_dim, (kInputH / 8) * (kInputW / 8) }); - ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); - - IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(Dims2{ first_dim, (kInputH / 16) * (kInputW / 16) }); - ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); - - IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - shuffle22_2->setReshapeDimensions(Dims2{ first_dim, (kInputH / 32) * (kInputW / 32) }); - ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); - - IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 8) * (kInputW / 8) }); + + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 16) * (kInputW / 16) }); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 32) * (kInputW / 32) }); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); + + nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); - builder->setMaxBatchSize(batchsize); + builder->setMaxBatchSize(kBatchSize); config->setMaxWorkspaceSize(16 * (1 << 20)); #if defined(USE_FP16) - config->setFlag(BuilderFlag::kFP16); + config->setFlag(nvinfer1::BuilderFlag::kFP16); #elif defined(USE_INT8) std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; assert(builder->platformHasFastInt8()); - config->setFlag(BuilderFlag::kINT8); - Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + nvinfer1::IInt8EntropyCalibrator2* calibrator = new Calibrator(1, kInputW, kInputH, "../calibrator/", "int8calib.table", kInputTensorName); config->setInt8Calibrator(calibrator); #endif std::cout << "Building engine, please wait for a while..." << std::endl; - IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); std::cout << "Build engine successfully!" << std::endl; delete network; @@ -665,154 +642,151 @@ IHostMemory* buildEngineYolov8l(const int& batchsize, IBuilder* builder, } -IHostMemory* buildEngineYolov8x(const int& batchsize, IBuilder* builder, - IBuilderConfig* config, DataType dt, const std::string& wts_path) { - std::map weightMap = loadWeights(wts_path); - INetworkDefinition* network = builder->createNetworkV2(0U); +nvinfer1::IHostMemory* buildEngineYolov8x(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, nvinfer1::DataType dt, const std::string& wts_path) { + std::map weightMap = loadWeights(wts_path); + nvinfer1::INetworkDefinition* network = builder->createNetworkV2(0U); /******************************************************************************************************* ****************************************** YOLOV8 INPUT ********************************************** *******************************************************************************************************/ - ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW }); + nvinfer1::ITensor* data = network->addInput(kInputTensorName, dt, nvinfer1::Dims3{ 3, kInputH, kInputW }); assert(data); /******************************************************************************************************* ***************************************** YOLOV8 BACKBONE ******************************************** *******************************************************************************************************/ - IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, kNumClass, 3, 2, 1, "model.0"); - IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 160, 3, 2, 1, "model.1"); - IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 160, 160, 3, true, 0.5, "model.2"); - IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 320, 3, 2, 1, "model.3"); - IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 320, 320, 6, true, 0.5, "model.4"); - IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 640, 3, 2, 1, "model.5"); - IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 640, 640, 6, true, 0.5, "model.6"); - IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 640, 3, 2, 1, "model.7"); - IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 640, 640, 3, true, 0.5, "model.8"); - IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 640, 640, 5, "model.9"); + nvinfer1::IElementWiseLayer* conv0 = convBnSiLU(network, weightMap, *data, 80, 3, 2, 1, "model.0"); + nvinfer1::IElementWiseLayer* conv1 = convBnSiLU(network, weightMap, *conv0->getOutput(0), 160, 3, 2, 1, "model.1"); + nvinfer1::IElementWiseLayer* conv2 = C2F(network, weightMap, *conv1->getOutput(0), 160, 160, 3, true, 0.5, "model.2"); + nvinfer1::IElementWiseLayer* conv3 = convBnSiLU(network, weightMap, *conv2->getOutput(0), 320, 3, 2, 1, "model.3"); + nvinfer1::IElementWiseLayer* conv4 = C2F(network, weightMap, *conv3->getOutput(0), 320, 320, 6, true, 0.5, "model.4"); + nvinfer1::IElementWiseLayer* conv5 = convBnSiLU(network, weightMap, *conv4->getOutput(0), 640, 3, 2, 1, "model.5"); + nvinfer1::IElementWiseLayer* conv6 = C2F(network, weightMap, *conv5->getOutput(0), 640, 640, 6, true, 0.5, "model.6"); + nvinfer1::IElementWiseLayer* conv7 = convBnSiLU(network, weightMap, *conv6->getOutput(0), 640, 3, 2, 1, "model.7"); + nvinfer1::IElementWiseLayer* conv8 = C2F(network, weightMap, *conv7->getOutput(0), 640, 640, 3, true, 0.5, "model.8"); + nvinfer1::IElementWiseLayer* conv9 = SPPF(network, weightMap, *conv8->getOutput(0), 640, 640, 5, "model.9"); /******************************************************************************************************* ****************************************** YOLOV8 HEAD *********************************************** *******************************************************************************************************/ float scale[] = { 1.0, 2.0, 2.0 }; - IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); - upsample10->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample10 = network->addResize(*conv9->getOutput(0)); + upsample10->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample10->setScales(scale, 3); - ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; - IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); - IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 640, 640, 3, false, 0.5, "model.12"); + nvinfer1::ITensor* inputTensor11[] = { upsample10->getOutput(0), conv6->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat11 = network->addConcatenation(inputTensor11, 2); + nvinfer1::IElementWiseLayer* conv12 = C2F(network, weightMap, *cat11->getOutput(0), 640, 640, 3, false, 0.5, "model.12"); - IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); - upsample13->setResizeMode(ResizeMode::kNEAREST); + nvinfer1::IResizeLayer* upsample13 = network->addResize(*conv12->getOutput(0)); + upsample13->setResizeMode(nvinfer1::ResizeMode::kNEAREST); upsample13->setScales(scale, 3); - ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; - IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); - IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 320, 320, 3, false, 0.5, "model.15"); - IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 320, 3, 2, 1, "model.16"); - ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; - IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); - IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 640, 640, 3, false, 0.5, "model.18"); - IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 640, 3, 2, 1, "model.19"); - ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; - IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); - IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 640, 640, 3, false, 0.5, "model.21"); + nvinfer1::ITensor* inputTensor14[] = { upsample13->getOutput(0), conv4->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat14 = network->addConcatenation(inputTensor14, 2); + nvinfer1::IElementWiseLayer* conv15 = C2F(network, weightMap, *cat14->getOutput(0), 320, 320, 3, false, 0.5, "model.15"); + nvinfer1::IElementWiseLayer* conv16 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 320, 3, 2, 1, "model.16"); + nvinfer1::ITensor* inputTensor17[] = { conv16->getOutput(0), conv12->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat17 = network->addConcatenation(inputTensor17, 2); + nvinfer1::IElementWiseLayer* conv18 = C2F(network, weightMap, *cat17->getOutput(0), 640, 640, 3, false, 0.5, "model.18"); + nvinfer1::IElementWiseLayer* conv19 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 640, 3, 2, 1, "model.19"); + nvinfer1::ITensor* inputTensor20[] = { conv19->getOutput(0), conv9->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat20 = network->addConcatenation(inputTensor20, 2); + nvinfer1::IElementWiseLayer* conv21 = C2F(network, weightMap, *cat20->getOutput(0), 640, 640, 3, false, 0.5, "model.21"); /******************************************************************************************************* ********************************************* YOLOV8 OUTPUT ****************************************** *******************************************************************************************************/ // output0 - IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.0.0"); - IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.0.1"); - IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); - conv22_cv2_0_2->setStrideNd(DimsHW{ 1, 1 }); - conv22_cv2_0_2->setPaddingNd(DimsHW{ 0, 0 }); - - IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 320, 3, 1, 1, "model.22.cv3.0.0"); - IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.0.1"); - IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); - conv22_cv3_0_2->setStride(DimsHW{ 1, 1 }); - conv22_cv3_0_2->setPadding(DimsHW{ 0, 0 }); - ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; - IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); + nvinfer1::IElementWiseLayer* conv22_cv2_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 80, 3, 1, 1, "model.22.cv2.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_0_1 = convBnSiLU(network, weightMap, *conv22_cv2_0_0->getOutput(0), 80, 3, 1, 1, "model.22.cv2.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_0_2 = network->addConvolutionNd(*conv22_cv2_0_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.0.2.weight"], weightMap["model.22.cv2.0.2.bias"]); + conv22_cv2_0_2->setStrideNd(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv2_0_2->setPaddingNd(nvinfer1::DimsHW{ 0, 0 }); + + nvinfer1::IElementWiseLayer* conv22_cv3_0_0 = convBnSiLU(network, weightMap, *conv15->getOutput(0), 320, 3, 1, 1, "model.22.cv3.0.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_0_1 = convBnSiLU(network, weightMap, *conv22_cv3_0_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.0.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_0_2 = network->addConvolutionNd(*conv22_cv3_0_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.0.2.weight"], weightMap["model.22.cv3.0.2.bias"]); + conv22_cv3_0_2->setStride(nvinfer1::DimsHW{ 1, 1 }); + conv22_cv3_0_2->setPadding(nvinfer1::DimsHW{ 0, 0 }); + nvinfer1::ITensor* inputTensor22_0[] = { conv22_cv2_0_2->getOutput(0), conv22_cv3_0_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_0 = network->addConcatenation(inputTensor22_0, 2); // output1 - IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.1.0"); - IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.1.1"); - IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); - conv22_cv2_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv2_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv2_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 80, 3, 1, 1, "model.22.cv2.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_1_1 = convBnSiLU(network, weightMap, *conv22_cv2_1_0->getOutput(0), 80, 3, 1, 1, "model.22.cv2.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_1_2 = network->addConvolutionNd(*conv22_cv2_1_1->getOutput(0), 64, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv2.1.2.weight"], weightMap["model.22.cv2.1.2.bias"]); + conv22_cv2_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv2_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 320, 3, 1, 1, "model.22.cv3.1.0"); - IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.1.1"); - IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); - conv22_cv3_1_2->setStrideNd(DimsHW{ 1,1 }); - conv22_cv3_1_2->setPaddingNd(DimsHW{ 0,0 }); + nvinfer1::IElementWiseLayer* conv22_cv3_1_0 = convBnSiLU(network, weightMap, *conv18->getOutput(0), 320, 3, 1, 1, "model.22.cv3.1.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_1_1 = convBnSiLU(network, weightMap, *conv22_cv3_1_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.1.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_1_2 = network->addConvolutionNd(*conv22_cv3_1_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1, 1 }, weightMap["model.22.cv3.1.2.weight"], weightMap["model.22.cv3.1.2.bias"]); + conv22_cv3_1_2->setStrideNd(nvinfer1::DimsHW{ 1,1 }); + conv22_cv3_1_2->setPaddingNd(nvinfer1::DimsHW{ 0,0 }); - ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; - IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); + nvinfer1::ITensor* inputTensor22_1[] = { conv22_cv2_1_2->getOutput(0), conv22_cv3_1_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_1 = network->addConcatenation(inputTensor22_1, 2); // output2 - IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.2.0"); - IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), kNumClass, 3, 1, 1, "model.22.cv2.2.1"); - IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv2_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 80, 3, 1, 1, "model.22.cv2.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv2_2_1 = convBnSiLU(network, weightMap, *conv22_cv2_2_0->getOutput(0), 80, 3, 1, 1, "model.22.cv2.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv2_2_2 = network->addConvolution(*conv22_cv2_2_1->getOutput(0), 64, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv2.2.2.weight"], weightMap["model.22.cv2.2.2.bias"]); - IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 320, 3, 1, 1, "model.22.cv3.2.0"); - IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.2.1"); - IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); + nvinfer1::IElementWiseLayer* conv22_cv3_2_0 = convBnSiLU(network, weightMap, *conv21->getOutput(0), 320, 3, 1, 1, "model.22.cv3.2.0"); + nvinfer1::IElementWiseLayer* conv22_cv3_2_1 = convBnSiLU(network, weightMap, *conv22_cv3_2_0->getOutput(0), 320, 3, 1, 1, "model.22.cv3.2.1"); + nvinfer1::IConvolutionLayer* conv22_cv3_2_2 = network->addConvolution(*conv22_cv3_2_1->getOutput(0), kNumClass, nvinfer1::DimsHW{ 1,1 }, weightMap["model.22.cv3.2.2.weight"], weightMap["model.22.cv3.2.2.bias"]); - ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; - IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); + nvinfer1::ITensor* inputTensor22_2[] = { conv22_cv2_2_2->getOutput(0), conv22_cv3_2_2->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_2 = network->addConcatenation(inputTensor22_2, 2); /******************************************************************************************************* ********************************************* YOLOV8 DETECT ****************************************** *******************************************************************************************************/ - IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); - nvinfer1::Dims shuffle22_0_shape = shuffle22_0->getOutput(0)->getDimensions(); - int first_dim = shuffle22_0_shape.d[0]; - - - shuffle22_0->setReshapeDimensions(Dims2{ first_dim, (kInputH / 8) * (kInputW / 8) }); - ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); - - IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); - shuffle22_1->setReshapeDimensions(Dims2{ first_dim, (kInputH / 16) * (kInputW / 16) }); - ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); - - IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); - shuffle22_2->setReshapeDimensions(Dims2{ first_dim, (kInputH / 32) * (kInputW / 32) }); - ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 0, 0 }, Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), Dims2{ 64, 0 }, Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, Dims2{ 1,1 }); - IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); - ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; - IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); - - IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); + nvinfer1::IShuffleLayer* shuffle22_0 = network->addShuffle(*cat22_0->getOutput(0)); + shuffle22_0->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 8) * (kInputW / 8) }); + + nvinfer1::ISliceLayer* split22_0_0 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_0_1 = network->addSlice(*shuffle22_0->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 8) * (kInputW / 8) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_0 = DFL(network, weightMap, *split22_0_0->getOutput(0), 4, (kInputH / 8) * (kInputW / 8), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_0[] = { dfl22_0->getOutput(0), split22_0_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_0 = network->addConcatenation(inputTensor22_dfl_0, 2); + + nvinfer1::IShuffleLayer* shuffle22_1 = network->addShuffle(*cat22_1->getOutput(0)); + shuffle22_1->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 16) * (kInputW / 16) }); + nvinfer1::ISliceLayer* split22_1_0 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_1_1 = network->addSlice(*shuffle22_1->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 16) * (kInputW / 16) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_1 = DFL(network, weightMap, *split22_1_0->getOutput(0), 4, (kInputH / 16) * (kInputW / 16), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_1[] = { dfl22_1->getOutput(0), split22_1_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_1 = network->addConcatenation(inputTensor22_dfl_1, 2); + + nvinfer1::IShuffleLayer* shuffle22_2 = network->addShuffle(*cat22_2->getOutput(0)); + shuffle22_2->setReshapeDimensions(nvinfer1::Dims2{ 64 + kNumClass, (kInputH / 32) * (kInputW / 32) }); + nvinfer1::ISliceLayer* split22_2_0 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 0, 0 }, nvinfer1::Dims2{ 64, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::ISliceLayer* split22_2_1 = network->addSlice(*shuffle22_2->getOutput(0), nvinfer1::Dims2{ 64, 0 }, nvinfer1::Dims2{ kNumClass, (kInputH / 32) * (kInputW / 32) }, nvinfer1::Dims2{ 1,1 }); + nvinfer1::IShuffleLayer* dfl22_2 = DFL(network, weightMap, *split22_2_0->getOutput(0), 4, (kInputH / 32) * (kInputW / 32), 1, 1, 0, "model.22.dfl.conv.weight"); + nvinfer1::ITensor* inputTensor22_dfl_2[] = { dfl22_2->getOutput(0), split22_2_1->getOutput(0) }; + nvinfer1::IConcatenationLayer* cat22_dfl_2 = network->addConcatenation(inputTensor22_dfl_2, 2); + + nvinfer1::IPluginV2Layer* yolo = addYoLoLayer(network, std::vector{cat22_dfl_0, cat22_dfl_1, cat22_dfl_2}); yolo->getOutput(0)->setName(kOutputTensorName); network->markOutput(*yolo->getOutput(0)); - builder->setMaxBatchSize(batchsize); + builder->setMaxBatchSize(kBatchSize); config->setMaxWorkspaceSize(16 * (1 << 20)); #if defined(USE_FP16) - config->setFlag(BuilderFlag::kFP16); + config->setFlag(nvinfer1::BuilderFlag::kFP16); #elif defined(USE_INT8) std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; assert(builder->platformHasFastInt8()); - config->setFlag(BuilderFlag::kINT8); - Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName); + config->setFlag(nvinfer1::BuilderFlag::kINT8); + nvinfer1::IInt8EntropyCalibrator2* calibrator = new Calibrator(1, kInputW, kInputH, "../calibrator/", "int8calib.table", kInputTensorName); config->setInt8Calibrator(calibrator); #endif std::cout << "Building engine, please wait for a while..." << std::endl; - IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); + nvinfer1::IHostMemory* serialized_model = builder->buildSerializedNetwork(*network, *config); std::cout << "Build engine successfully!" << std::endl; delete network; diff --git a/yolov8/src/postprocess.cpp b/yolov8/src/postprocess.cpp index d09bf79e..38c482c6 100644 --- a/yolov8/src/postprocess.cpp +++ b/yolov8/src/postprocess.cpp @@ -82,7 +82,6 @@ void batch_nms(std::vector> &res_batch, float *output, in } } - void process_decode_ptr_host(std::vector &res, const float* decode_ptr_host, int bbox_element, cv::Mat& img, int count) { Detection det; for (int i = 0; i < count; i++) { @@ -110,7 +109,6 @@ void batch_process(std::vector> &res_batch, const float* } } - void draw_bbox(std::vector &img_batch, std::vector> &res_batch) { for (size_t i = 0; i < img_batch.size(); i++) { auto &res = res_batch[i]; diff --git a/yolov8/src/postprocess.cu b/yolov8/src/postprocess.cu new file mode 100644 index 00000000..dd5ae2c9 --- /dev/null +++ b/yolov8/src/postprocess.cu @@ -0,0 +1,98 @@ +// +// Created by lindsay on 23-7-17. +// +#include "types.h" +#include "postprocess.h" + +static __global__ void +decode_kernel(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects) { + + float count = predict[0]; + int position = (blockDim.x * blockIdx.x + threadIdx.x); + if (position >= count) + return; + float *pitem = predict + 1 + position * 6; + int index = atomicAdd(parray, 1); + if (index >= max_objects) + return; + float confidence = pitem[4]; + if (confidence < confidence_threshold) + return; + float left = pitem[0]; + float top = pitem[1]; + float right = pitem[2]; + float bottom = pitem[3]; + float label = pitem[5]; + float *pout_item = parray + 1 + index * bbox_element; + *pout_item++ = left; + *pout_item++ = top; + *pout_item++ = right; + *pout_item++ = bottom; + *pout_item++ = confidence; + *pout_item++ = label; + *pout_item++ = 1; // 1 = keep, 0 = ignore +} + +static __device__ float +box_iou(float aleft, float atop, float aright, float abottom, float bleft, float btop, float bright, float bbottom) { + + float cleft = max(aleft, bleft); + float ctop = max(atop, btop); + float cright = min(aright, bright); + float cbottom = min(abottom, bbottom); + + float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f); + if (c_area == 0.0f) + return 0.0f; + + float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop); + float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop); + return c_area / (a_area + b_area - c_area); +} + +static __global__ void nms_kernel(float *bboxes, int max_objects, float threshold) { + + int position = (blockDim.x * blockIdx.x + threadIdx.x); + int count = bboxes[0]; + + // float count = 0.0f; + if (position >= count) + return; + + float *pcurrent = bboxes + 1 + position * bbox_element; + for (int i = 1; i < count; ++i) { + float *pitem = bboxes + 1 + i * bbox_element; + if (i == position || pcurrent[5] != pitem[5]) continue; + + if (pitem[4] >= pcurrent[4]) { + if (pitem[4] == pcurrent[4] && i < position) + continue; + + float iou = box_iou( + pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3], + pitem[0], pitem[1], pitem[2], pitem[3] + ); + + if (iou > threshold) { + pcurrent[6] = 0; + return; + } + } + } +} + +void cuda_decode(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects, + cudaStream_t stream) { + int block = 256; + int grid = ceil(num_bboxes / (float) block); + decode_kernel << < + grid, block, 0, stream >> > ((float *) predict, num_bboxes, confidence_threshold, parray, max_objects); + +} + +void cuda_nms(float *parray, float nms_threshold, int max_objects, cudaStream_t stream) { + int block = max_objects < 256 ? max_objects : 256; + int grid = ceil(max_objects / (float) block); + nms_kernel << < grid, block, 0, stream >> > (parray, max_objects, nms_threshold); + +} diff --git a/yolov8/src/preprocess.cu b/yolov8/src/preprocess.cu index 376c2cba..14d9e778 100644 --- a/yolov8/src/preprocess.cu +++ b/yolov8/src/preprocess.cu @@ -87,82 +87,7 @@ warpaffine_kernel(uint8_t *src, int src_line_size, int src_width, int src_height *pdst_c2 = c2; } -static __global__ void -decode_kernel(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects) { - - float count = predict[0]; - int position = (blockDim.x * blockIdx.x + threadIdx.x); - if (position >= count) - return; - float *pitem = predict + 1 + position * 6; - int index = atomicAdd(parray, 1); - if (index >= max_objects) - return; - float confidence = pitem[4]; - if (confidence < confidence_threshold) - return; - float left = pitem[0]; - float top = pitem[1]; - float right = pitem[2]; - float bottom = pitem[3]; - float label = pitem[5]; - float *pout_item = parray + 1 + index * bbox_element; - *pout_item++ = left; - *pout_item++ = top; - *pout_item++ = right; - *pout_item++ = bottom; - *pout_item++ = confidence; - *pout_item++ = label; - *pout_item++ = 1; // 1 = keep, 0 = ignore -} - -static __device__ float -box_iou(float aleft, float atop, float aright, float abottom, float bleft, float btop, float bright, float bbottom) { - - float cleft = max(aleft, bleft); - float ctop = max(atop, btop); - float cright = min(aright, bright); - float cbottom = min(abottom, bbottom); - - float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f); - if (c_area == 0.0f) - return 0.0f; - - float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop); - float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop); - return c_area / (a_area + b_area - c_area); -} - -static __global__ void nms_kernel(float *bboxes, int max_objects, float threshold) { - - int position = (blockDim.x * blockIdx.x + threadIdx.x); - int count = bboxes[0]; - -// float count = 0.0f; - if (position >= count) - return; - - float *pcurrent = bboxes + 1 + position * bbox_element; - for (int i = 1; i < count; ++i) { - float *pitem = bboxes + 1 + i * bbox_element; - if (i == position || pcurrent[5] != pitem[5]) continue; - if (pitem[4] >= pcurrent[4]) { - if (pitem[4] == pcurrent[4] && i < position) - continue; - - float iou = box_iou( - pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3], - pitem[0], pitem[1], pitem[2], pitem[3] - ); - - if (iou > threshold) { - pcurrent[6] = 0; - return; - } - } - } -} void cuda_preprocess(uint8_t *src, int src_width, int src_height, float *dst, int dst_width, int dst_height, @@ -210,21 +135,7 @@ void cuda_batch_preprocess(std::vector &img_batch, } -void cuda_decode(float *predict, int num_bboxes, float confidence_threshold, float *parray, int max_objects, - cudaStream_t stream) { - int block = 256; - int grid = ceil(num_bboxes / (float) block); - decode_kernel << < - grid, block, 0, stream >> > ((float *) predict, num_bboxes, confidence_threshold, parray, max_objects); -} - -void cuda_nms(float *parray, float nms_threshold, int max_objects, cudaStream_t stream) { - int block = max_objects < 256 ? max_objects : 256; - int grid = ceil(max_objects / (float) block); - nms_kernel << < grid, block, 0, stream >> > (parray, max_objects, nms_threshold); - -} void cuda_preprocess_init(int max_image_size) {