[nvJPEG] Integrate nvJPEG decoder (PaddlePaddle#1288)
* nvjpeg cmake

* add common decoder, nvjpeg decoder and add image name predict api

* ppclas support nvjpeg decoder

* remove useless comments

* image decoder support opencv

* nvjpeg decode fallback to opencv

* fdtensor add nbytes_allocated

* single image decode api

* fix bug

* add pybind

* ignore nvjpeg on jetson

* fix cmake in

* predict on fdmat

* remove image names predict api, add image decoder tutorial

* Update __init__.py

* fix pybind
wang-xinyu authored Feb 17, 2023
1 parent e3a7ab4 commit efa4656
Showing 25 changed files with 875 additions and 44 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
@@ -300,10 +300,16 @@ if(WITH_GPU)
include_directories(${CUDA_DIRECTORY}/include)
if(WIN32)
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
add_definitions(-DENABLE_NVJPEG)
else()
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
if(NOT BUILD_ON_JETSON)
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
add_definitions(-DENABLE_NVJPEG)
endif()
endif()
list(APPEND DEPEND_LIBS ${CUDA_LIB})
list(APPEND DEPEND_LIBS ${CUDA_LIB} ${NVJPEG_LIB})

# build CUDA source files in fastdeploy, CUDA source files include CUDA preprocessing, TRT plugins, etc.
enable_language(CUDA)
10 changes: 7 additions & 3 deletions FastDeploy.cmake.in
@@ -169,21 +169,25 @@ if(ENABLE_POROS_BACKEND)
endif()

if(WITH_GPU)
if (NOT CUDA_DIRECTORY)
if(NOT CUDA_DIRECTORY)
set(CUDA_DIRECTORY "/usr/local/cuda")
endif()
if(WIN32)
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
else()
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
if(NOT BUILD_ON_JETSON)
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
endif()
endif()
if(NOT CUDA_LIB)
message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda")
endif()
list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB})
list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB} ${NVJPEG_LIB})
list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include)

if (ENABLE_TRT_BACKEND)
if(ENABLE_TRT_BACKEND)
if(BUILD_ON_JETSON)
find_library(TRT_INFER_LIB nvinfer /usr/lib/aarch64-linux-gnu/)
find_library(TRT_ONNX_LIB nvonnxparser /usr/lib/aarch64-linux-gnu/)
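Because FastDeploy.cmake.in now appends NVJPEG_LIB to FASTDEPLOY_LIBS, downstream projects that consume the SDK link nvJPEG transitively on non-Jetson GPU builds. A minimal consumer sketch, not part of this commit; FASTDEPLOY_INSTALL_DIR and infer.cc are placeholders:

cmake_minimum_required(VERSION 3.12)
project(infer_demo CXX)

# Path to an extracted FastDeploy SDK, assumed to have been built with -DWITH_GPU=ON.
set(FASTDEPLOY_INSTALL_DIR "" CACHE PATH "Path to the FastDeploy SDK")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo infer.cc)
# FASTDEPLOY_LIBS already carries CUDA_LIB and, when available, NVJPEG_LIB,
# so no extra find_library(nvjpeg) is needed in the consumer project.
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})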
14 changes: 10 additions & 4 deletions fastdeploy/core/fd_tensor.cc
@@ -245,12 +245,13 @@ void FDTensor::PrintInfo(const std::string& prefix) const {
bool FDTensor::ReallocFn(size_t nbytes) {
if (device == Device::GPU) {
#ifdef WITH_GPU
size_t original_nbytes = Nbytes();
size_t original_nbytes = nbytes_allocated;
if (nbytes > original_nbytes) {
if (buffer_ != nullptr) {
FDDeviceFree()(buffer_);
}
FDDeviceAllocator()(&buffer_, nbytes);
nbytes_allocated = nbytes;
}
return buffer_ != nullptr;
#else
@@ -262,12 +263,13 @@ bool FDTensor::ReallocFn(size_t nbytes) {
} else {
if (is_pinned_memory) {
#ifdef WITH_GPU
size_t original_nbytes = Nbytes();
size_t original_nbytes = nbytes_allocated;
if (nbytes > original_nbytes) {
if (buffer_ != nullptr) {
FDDeviceHostFree()(buffer_);
}
FDDeviceHostAllocator()(&buffer_, nbytes);
nbytes_allocated = nbytes;
}
return buffer_ != nullptr;
#else
@@ -278,6 +280,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
#endif
}
buffer_ = realloc(buffer_, nbytes);
nbytes_allocated = nbytes;
return buffer_ != nullptr;
}
}
@@ -299,6 +302,7 @@ void FDTensor::FreeFn() {
}
}
buffer_ = nullptr;
nbytes_allocated = 0;
}
}

@@ -380,7 +384,7 @@ FDTensor::FDTensor(const FDTensor& other)
device_id(other.device_id) {
// Copy buffer
if (other.buffer_ == nullptr) {
buffer_ = nullptr;
FreeFn();
} else {
size_t nbytes = Nbytes();
FDASSERT(ReallocFn(nbytes),
@@ -396,7 +400,8 @@ FDTensor::FDTensor(FDTensor&& other)
dtype(other.dtype),
external_data_ptr(other.external_data_ptr),
device(other.device),
device_id(other.device_id) {
device_id(other.device_id),
nbytes_allocated(other.nbytes_allocated) {
other.name = "";
// Note(zhoushunjie): Avoid double free.
other.buffer_ = nullptr;
@@ -435,6 +440,7 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
dtype = other.dtype;
device = other.device;
device_id = other.device_id;
nbytes_allocated = other.nbytes_allocated;

other.name = "";
// Note(zhoushunjie): Avoid double free.
5 changes: 5 additions & 0 deletions fastdeploy/core/fd_tensor.h
@@ -54,6 +54,11 @@ struct FASTDEPLOY_DECL FDTensor {
// other devices' data
std::vector<int8_t> temporary_cpu_buffer;

// The number of bytes allocated so far.
// When resizing GPU memory, we will free and realloc the memory only if the
// required size is larger than this value.
size_t nbytes_allocated = 0;

// Get data buffer pointer
void* MutableData();

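The nbytes_allocated field turns ReallocFn into a grow-only allocator: device memory is freed and re-allocated only when the requested size exceeds the bytes already backing buffer_, so shrink/grow cycles (e.g. varying batch sizes) reuse the existing allocation. A simplified host-memory sketch of the same policy, with plain malloc/free standing in for the CUDA allocators used by the real GPU path:

#include <cstdio>
#include <cstdlib>

// Standalone illustration only; FDDeviceAllocator/FDDeviceFree are replaced by
// malloc/free so the sketch runs anywhere.
struct Buffer {
  void* data = nullptr;
  size_t nbytes_allocated = 0;  // capacity currently backing `data`

  bool Realloc(size_t nbytes) {
    if (nbytes > nbytes_allocated) {  // grow only; smaller requests keep the block
      std::free(data);
      data = std::malloc(nbytes);
      nbytes_allocated = (data != nullptr) ? nbytes : 0;
    }
    return data != nullptr;
  }
  ~Buffer() { std::free(data); }
};

int main() {
  Buffer b;
  b.Realloc(1 << 20);  // allocate 1 MiB
  void* first = b.data;
  b.Realloc(1 << 10);  // logical shrink: no free/realloc
  b.Realloc(1 << 20);  // grow back within capacity: same pointer
  std::printf("buffer reused: %s\n", b.data == first ? "yes" : "no");
  return 0;
}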
40 changes: 28 additions & 12 deletions fastdeploy/vision/classification/ppcls/model.cc
100755 → 100644
@@ -13,6 +13,7 @@
// limitations under the License.

#include "fastdeploy/vision/classification/ppcls/model.h"

#include "fastdeploy/utils/unique_ptr.h"

namespace fastdeploy {
@@ -23,7 +24,8 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) : preprocessor_(config_file) {
const ModelFormat& model_format)
: preprocessor_(config_file) {
if (model_format == ModelFormat::PADDLE) {
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT,
Backend::LITE};
@@ -32,24 +34,24 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
valid_ascend_backends = {Backend::LITE};
valid_kunlunxin_backends = {Backend::LITE};
valid_ipu_backends = {Backend::PDINFER};
}else if (model_format == ModelFormat::SOPHGO) {
} else if (model_format == ModelFormat::SOPHGO) {
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
}
else {
} else {
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
valid_rknpu_backends = {Backend::RKNPU2};
}

runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}

std::unique_ptr<PaddleClasModel> PaddleClasModel::Clone() const {
std::unique_ptr<PaddleClasModel> clone_model = utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
std::unique_ptr<PaddleClasModel> PaddleClasModel::Clone() const {
std::unique_ptr<PaddleClasModel> clone_model =
utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
clone_model->SetRuntime(clone_model->CloneRuntime());
return clone_model;
}
@@ -71,17 +73,30 @@ bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
}

bool PaddleClasModel::Predict(const cv::Mat& im, ClassifyResult* result) {
FDMat mat = WrapMat(im);
return Predict(mat, result);
}

bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<ClassifyResult>* results) {
std::vector<FDMat> mats = WrapMat(images);
return BatchPredict(mats, results);
}

bool PaddleClasModel::Predict(const FDMat& mat, ClassifyResult* result) {
std::vector<ClassifyResult> results;
if (!BatchPredict({im}, &results)) {
std::vector<FDMat> mats = {mat};
if (!BatchPredict(mats, &results)) {
return false;
}
*result = std::move(results[0]);
return true;
}

bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images, std::vector<ClassifyResult>* results) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
bool PaddleClasModel::BatchPredict(const std::vector<FDMat>& mats,
std::vector<ClassifyResult>* results) {
std::vector<FDMat> fd_mats = mats;
if (!preprocessor_.Run(&fd_mats, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
@@ -92,7 +107,8 @@ bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images, std::vector<ClassifyResult>* results) {
}

if (!postprocessor_.Run(reused_output_tensors_, results)) {
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
FDERROR << "Failed to postprocess the inference results by runtime."
<< std::endl;
return false;
}

17 changes: 17 additions & 0 deletions fastdeploy/vision/classification/ppcls/model.h
@@ -75,6 +75,23 @@ class FASTDEPLOY_DECL PaddleClasModel : public FastDeployModel {
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<ClassifyResult>* results);

/** \brief Predict the classification result for an input image
*
* \param[in] mat The input mat
* \param[in] result The output classification result
* \return true if the prediction succeeded, otherwise false
*/
virtual bool Predict(const FDMat& mat, ClassifyResult* result);

/** \brief Predict the classification results for a batch of input images
*
* \param[in] mats The input mat list
* \param[in] results The output classification result list
* \return true if the prediction succeeded, otherwise false
*/
virtual bool BatchPredict(const std::vector<FDMat>& mats,
std::vector<ClassifyResult>* results);

/// Get preprocessor reference of PaddleClasModel
virtual PaddleClasPreprocessor& GetPreprocessor() {
return preprocessor_;
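A hedged usage sketch of the new FDMat overload declared above; the model paths are placeholders, and WrapMat is the same helper model.cc uses to wrap a cv::Mat without copying:

#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  // Hypothetical example of the new Predict(const FDMat&, ClassifyResult*) API;
  // paths are placeholders and error handling is minimal.
  fastdeploy::vision::classification::PaddleClasModel model(
      "ResNet50_vd_infer/inference.pdmodel",
      "ResNet50_vd_infer/inference.pdiparams",
      "ResNet50_vd_infer/inference_cls.yaml");
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize the model." << std::endl;
    return -1;
  }

  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::FDMat mat = fastdeploy::vision::WrapMat(im);

  fastdeploy::vision::ClassifyResult result;
  if (!model.Predict(mat, &result)) {
    std::cerr << "Prediction failed." << std::endl;
    return -1;
  }
  std::cout << result.Str() << std::endl;
  return 0;
}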
112 changes: 112 additions & 0 deletions fastdeploy/vision/common/image_decoder/image_decoder.cc
@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision/common/image_decoder/image_decoder.h"

#include "opencv2/imgcodecs.hpp"

namespace fastdeploy {
namespace vision {

ImageDecoder::ImageDecoder(ImageDecoderLib lib) {
if (lib == ImageDecoderLib::NVJPEG) {
#ifdef ENABLE_NVJPEG
nvjpeg::init_decoder(nvjpeg_params_);
#endif
}
lib_ = lib;
}

ImageDecoder::~ImageDecoder() {
if (lib_ == ImageDecoderLib::NVJPEG) {
#ifdef ENABLE_NVJPEG
nvjpeg::destroy_decoder(nvjpeg_params_);
#endif
}
}

bool ImageDecoder::Decode(const std::string& img_name, FDMat* mat) {
std::vector<FDMat> mats(1);
mats[0] = std::move(*mat);
if (!BatchDecode({img_name}, &mats)) {
return false;
}
*mat = std::move(mats[0]);
return true;
}

bool ImageDecoder::BatchDecode(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
if (lib_ == ImageDecoderLib::OPENCV) {
return ImplByOpenCV(img_names, mats);
} else if (lib_ == ImageDecoderLib::NVJPEG) {
return ImplByNvJpeg(img_names, mats);
}
return true;
}

bool ImageDecoder::ImplByOpenCV(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
for (size_t i = 0; i < img_names.size(); ++i) {
cv::Mat im = cv::imread(img_names[i]);
(*mats)[i].SetMat(im);
(*mats)[i].layout = Layout::HWC;
(*mats)[i].SetWidth(im.cols);
(*mats)[i].SetHeight(im.rows);
(*mats)[i].SetChannels(im.channels());
}
return true;
}

bool ImageDecoder::ImplByNvJpeg(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
#ifdef ENABLE_NVJPEG
nvjpeg_params_.batch_size = img_names.size();
std::vector<nvjpegImage_t> output_imgs(nvjpeg_params_.batch_size);
std::vector<int> widths(nvjpeg_params_.batch_size);
std::vector<int> heights(nvjpeg_params_.batch_size);
// TODO(wangxinyu): support other output format
nvjpeg_params_.fmt = NVJPEG_OUTPUT_BGRI;
double total;
nvjpeg_params_.stream = (*mats)[0].Stream();

std::vector<FDTensor*> output_buffers;
for (size_t i = 0; i < mats->size(); ++i) {
FDASSERT((*mats)[i].output_cache != nullptr,
"The output_cache of FDMat was not set.");
output_buffers.push_back((*mats)[i].output_cache);
}

if (nvjpeg::process_images(img_names, nvjpeg_params_, total, output_imgs,
output_buffers, widths, heights)) {
    // If nvJPEG decoding fails, fall back to OpenCV,
    // e.g. the PNG format is not supported by nvJPEG.
FDWARNING << "nvJPEG decode failed, falling back to OpenCV for this batch"
<< std::endl;
return ImplByOpenCV(img_names, mats);
}

for (size_t i = 0; i < mats->size(); ++i) {
(*mats)[i].mat_type = ProcLib::CUDA;
(*mats)[i].layout = Layout::HWC;
(*mats)[i].SetTensor(output_buffers[i]);
}
#else
FDASSERT(false, "FastDeploy didn't compile with NVJPEG.");
#endif
return true;
}

} // namespace vision
} // namespace fastdeploy
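A rough driver for the new ImageDecoder (the commit also adds an image decoder tutorial); the file name is a placeholder, and per ImplByNvJpeg above the nvJPEG path requires output_cache to point at an FDTensor that receives the decoded pixels on the GPU:

#include <iostream>
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"

int main() {
  namespace fdv = fastdeploy::vision;

  // Assumes a GPU build with ENABLE_NVJPEG; unsupported formats (e.g. PNG)
  // fall back to OpenCV inside BatchDecode and the pixels stay on the CPU.
  fdv::ImageDecoder decoder(fdv::ImageDecoderLib::NVJPEG);

  fdv::FDMat mat;
  fastdeploy::FDTensor output_cache;
  mat.output_cache = &output_cache;  // required by the nvJPEG path (see the FDASSERT above)

  if (!decoder.Decode("test.jpg", &mat)) {
    std::cerr << "Failed to decode the image." << std::endl;
    return -1;
  }

  std::cout << "decoded on GPU: "
            << (mat.mat_type == fdv::ProcLib::CUDA ? "yes" : "no (OpenCV fallback)")
            << std::endl;
  return 0;
}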