[Backend] TRT backend & PP-Infer backend support pinned memory (PaddlePaddle#403)

* TRT backend uses pinned memory

* Refine FDTensor pinned memory logic

* Make TRT pinned memory use configurable

* Paddle Inference backend supports pinned memory

* Add pinned memory pybindings

Co-authored-by: Jason <[email protected]>
wang-xinyu and jiangjiajun authored Oct 21, 2022
1 parent 8dbc1f1 commit 43d8611
Showing 14 changed files with 120 additions and 18 deletions.
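
For orientation before the per-file diffs, here is a minimal C++ usage sketch of the option this commit introduces. Only EnablePinnedMemory() and DisablePinnedMemory() come from this change; SetModelPath(), UseGpu(), UseTrtBackend(), Runtime::Init() and Runtime::Infer() are assumed from the existing FastDeploy API, and the model files are placeholders.

#include <vector>
#include "fastdeploy/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // placeholder model files
  option.UseGpu(0);             // assumed existing API: run on GPU 0
  option.UseTrtBackend();       // assumed existing API: select the TensorRT backend
  option.EnablePinnedMemory();  // new in this commit: output tensors use pinned host memory

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) {
    return -1;
  }

  std::vector<fastdeploy::FDTensor> inputs, outputs;
  // ... fill `inputs`, then:
  // runtime.Infer(inputs, &outputs);
  return 0;
}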
2 changes: 2 additions & 0 deletions fastdeploy/backends/paddle/paddle_backend.cc
@@ -19,6 +19,7 @@
namespace fastdeploy {

void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
option_ = option;
if (option.use_gpu) {
config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
if (option.enable_trt) {
@@ -190,6 +191,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
outputs->resize(outputs_desc_.size());
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
(*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
CopyTensorToCpu(handle, &((*outputs)[i]));
}
return true;
2 changes: 2 additions & 0 deletions fastdeploy/backends/paddle/paddle_backend.h
@@ -53,6 +53,7 @@ struct PaddleBackendOption {
int gpu_mem_init_size = 100;
// gpu device id
int gpu_id = 0;
bool enable_pinned_memory = false;

std::vector<std::string> delete_pass_names = {};
};
@@ -105,6 +106,7 @@ class PaddleBackend : public BaseBackend {
std::map<std::string, std::vector<int>>* opt_shape) const;
void SetTRTDynamicShapeToConfig(const PaddleBackendOption& option);
#endif
PaddleBackendOption option_;
paddle_infer::Config config_;
std::shared_ptr<paddle_infer::Predictor> predictor_;
std::vector<TensorInfo> inputs_desc_;
2 changes: 1 addition & 1 deletion fastdeploy/backends/paddle/util.cc
@@ -67,7 +67,7 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
std::vector<int64_t> shape;
auto tmp_shape = tensor->shape();
shape.assign(tmp_shape.begin(), tmp_shape.end());
fd_tensor->Allocate(shape, fd_dtype, tensor->name());
fd_tensor->Resize(shape, fd_dtype, tensor->name());
if (fd_tensor->dtype == FDDataType::FP32) {
tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
return;
29 changes: 19 additions & 10 deletions fastdeploy/backends/tensorrt/trt_backend.cc
@@ -306,17 +306,21 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,

SetInputs(inputs);
AllocateOutputsBuffer(outputs);

if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
FDERROR << "Failed to Infer with TensorRT." << std::endl;
return false;
}
for (size_t i = 0; i < outputs->size(); ++i) {
FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
outputs_buffer_[(*outputs)[i].name].data(),
outputs_device_buffer_[(*outputs)[i].name].data(),
(*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
stream_) == 0,
"[ERROR] Error occurs while copy memory from GPU to CPU.");
}
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");

return true;
}

@@ -332,10 +336,10 @@ void TrtBackend::GetInputOutputInfo() {
auto dtype = engine_->getBindingDataType(i);
if (engine_->bindingIsInput(i)) {
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
inputs_buffer_[name] = FDDeviceBuffer(dtype);
inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
} else {
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
outputs_buffer_[name] = FDDeviceBuffer(dtype);
outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
}
}
bindings_.resize(num_binds);
@@ -357,30 +361,31 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
"please use INT32 input");
} else {
// no copy
inputs_buffer_[item.name].SetExternalData(dims, item.Data());
inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
}
} else {
// Allocate input buffer memory
inputs_buffer_[item.name].resize(dims);
inputs_device_buffer_[item.name].resize(dims);

// copy from cpu to gpu
if (item.dtype == FDDataType::INT64) {
int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
std::vector<int32_t> casted_data(data, data + item.Numel());
FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
static_cast<void*>(casted_data.data()),
item.Nbytes() / 2, cudaMemcpyHostToDevice,
stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
} else {
FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
item.Data(),
item.Nbytes(), cudaMemcpyHostToDevice,
stream_) == 0,
"Error occurs while copy memory from CPU to GPU.");
}
}
// binding input buffer
bindings_[idx] = inputs_buffer_[item.name].data();
bindings_[idx] = inputs_device_buffer_[item.name].data();
}
}

@@ -399,15 +404,19 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
"Cannot find output: %s of tensorrt network from the original model.",
outputs_desc_[i].name.c_str());
auto ori_idx = iter->second;

// set user's outputs info
std::vector<int64_t> shape(output_dims.d,
output_dims.d + output_dims.nbDims);
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
(*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
outputs_desc_[i].name);

// Allocate output buffer memory
outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);

// binding output buffer
bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
}
}

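
The Infer() change above copies each TensorRT output from outputs_device_buffer_ into an FDTensor whose host buffer may now be pinned. As a standalone illustration of why pinning matters, the CUDA sketch below mirrors that path: cudaMemcpyAsync can overlap with other stream work only when the host side is page-locked, and the cudaStreamSynchronize matches the guard in TrtBackend::Infer that ensures the copy has finished before the outputs are returned. Buffer size and variable names are illustrative.

#include <cuda_runtime.h>

int main() {
  const size_t nbytes = 1 << 20;  // 1 MiB, illustrative
  void* device_buf = nullptr;
  void* pinned_host_buf = nullptr;
  cudaStream_t stream;

  cudaStreamCreate(&stream);
  cudaMalloc(&device_buf, nbytes);
  cudaMallocHost(&pinned_host_buf, nbytes);  // page-locked, as in FDDeviceHostAllocator

  // Asynchronous device-to-host copy; it is truly asynchronous only because
  // the destination is pinned. Pageable destinations force a staged copy.
  cudaMemcpyAsync(pinned_host_buf, device_buf, nbytes,
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // make sure the data is on the host before use

  cudaFreeHost(pinned_host_buf);
  cudaFree(device_buf);
  cudaStreamDestroy(stream);
  return 0;
}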
5 changes: 3 additions & 2 deletions fastdeploy/backends/tensorrt/trt_backend.h
@@ -70,6 +70,7 @@ struct TrtBackendOption {
std::map<std::string, std::vector<int32_t>> min_shape;
std::map<std::string, std::vector<int32_t>> opt_shape;
std::string serialize_file = "";
bool enable_pinned_memory = false;

// inside parameter, maybe remove next version
bool remove_multiclass_nms_ = false;
@@ -118,8 +119,8 @@ class TrtBackend : public BaseBackend {
std::vector<void*> bindings_;
std::vector<TrtValueInfo> inputs_desc_;
std::vector<TrtValueInfo> outputs_desc_;
std::map<std::string, FDDeviceBuffer> inputs_buffer_;
std::map<std::string, FDDeviceBuffer> outputs_buffer_;
std::map<std::string, FDDeviceBuffer> inputs_device_buffer_;
std::map<std::string, FDDeviceBuffer> outputs_device_buffer_;

std::string calibration_str_;

2 changes: 2 additions & 0 deletions fastdeploy/backends/tensorrt/utils.h
@@ -206,6 +206,8 @@ class FDGenericBuffer {
};

using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
using FDDeviceHostBuffer = FDGenericBuffer<FDDeviceHostAllocator,
FDDeviceHostFree>;

class FDTrtLogger : public nvinfer1::ILogger {
public:
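
The FDDeviceHostBuffer alias is declared here but not used elsewhere in this diff. The sketch below is an assumption about its intended use, composing it with the existing FDGenericBuffer template through the same member calls that appear in trt_backend.cc above (construction from a TensorRT dtype, resize(), data()).

#include "fastdeploy/backends/tensorrt/utils.h"

void PinnedStagingExample(const nvinfer1::Dims& dims) {
  // Same interface as FDDeviceBuffer, but backed by cudaMallocHost /
  // cudaFreeHost through FDDeviceHostAllocator / FDDeviceHostFree.
  fastdeploy::FDDeviceHostBuffer staging(nvinfer1::DataType::kFLOAT);
  staging.resize(dims);             // pinned host allocation sized for `dims`
  void* host_ptr = staging.data();  // usable as the host side of cudaMemcpyAsync
  (void)host_ptr;
}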
6 changes: 6 additions & 0 deletions fastdeploy/core/allocate.cc
@@ -34,6 +34,12 @@ bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {

void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }

bool FDDeviceHostAllocator::operator()(void** ptr, size_t size) const {
return cudaMallocHost(ptr, size) == cudaSuccess;
}

void FDDeviceHostFree::operator()(void* ptr) const { cudaFreeHost(ptr); }

#endif

} // namespace fastdeploy
10 changes: 10 additions & 0 deletions fastdeploy/core/allocate.h
@@ -45,6 +45,16 @@ class FASTDEPLOY_DECL FDDeviceFree {
void operator()(void* ptr) const;
};

class FASTDEPLOY_DECL FDDeviceHostAllocator {
public:
bool operator()(void** ptr, size_t size) const;
};

class FASTDEPLOY_DECL FDDeviceHostFree {
public:
void operator()(void* ptr) const;
};

#endif

} // namespace fastdeploy
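
A minimal sketch of the calling convention for the new functor pair, following the pattern of the existing FDDeviceAllocator/FDDeviceFree: allocation returns a bool, deallocation returns void. This is an illustrative standalone use; in the codebase the pair is driven by FDTensor::ReallocFn and by the FDDeviceHostBuffer alias above.

#include <cstddef>
#include "fastdeploy/core/allocate.h"

void PinnedScratchExample(size_t nbytes) {
  void* ptr = nullptr;
  // Stateless functors: construct and invoke in place.
  if (fastdeploy::FDDeviceHostAllocator()(&ptr, nbytes)) {
    // ... stage data here for a later cudaMemcpyAsync ...
    fastdeploy::FDDeviceHostFree()(ptr);
  }
}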
45 changes: 40 additions & 5 deletions fastdeploy/core/fd_tensor.cc
@@ -207,9 +207,27 @@ bool FDTensor::ReallocFn(size_t nbytes) {
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
} else {
if (is_pinned_memory) {
#ifdef WITH_GPU
size_t original_nbytes = Nbytes();
if (nbytes > original_nbytes) {
if (buffer_ != nullptr) {
FDDeviceHostFree()(buffer_);
}
FDDeviceHostAllocator()(&buffer_, nbytes);
}
return buffer_ != nullptr;
#else
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif
}
buffer_ = realloc(buffer_, nbytes);
return buffer_ != nullptr;
}
buffer_ = realloc(buffer_, nbytes);
return buffer_ != nullptr;
}

void FDTensor::FreeFn() {
@@ -220,7 +238,13 @@ void FDTensor::FreeFn() {
FDDeviceFree()(buffer_);
#endif
} else {
FDHostFree()(buffer_);
if (is_pinned_memory) {
#ifdef WITH_GPU
FDDeviceHostFree()(buffer_);
#endif
} else {
FDHostFree()(buffer_);
}
}
buffer_ = nullptr;
}
@@ -231,15 +255,26 @@ void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
#ifdef WITH_GPU
FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
"[ERROR] Error occurs while copy memory from GPU to GPU");

#else
FDASSERT(false,
"The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
"gpu buffer is "
"an unexpected problem happend.");
#endif
} else {
std::memcpy(dst, src, nbytes);
if (is_pinned_memory) {
#ifdef WITH_GPU
FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToHost) == 0,
"[ERROR] Error occurs while copy memory from host to host");
#else
FDASSERT(false,
"The FastDeploy didn't compile under -DWITH_GPU=ON, so copying "
"gpu buffer is "
"an unexpected problem happend.");
#endif
} else {
std::memcpy(dst, src, nbytes);
}
}
}

4 changes: 4 additions & 0 deletions fastdeploy/core/fd_tensor.h
@@ -40,6 +40,10 @@ struct FASTDEPLOY_DECL FDTensor {
// so we can skip data transfer, which may improve the efficience
Device device = Device::CPU;

// Whether the data buffer is in pinned memory, which is allocated
// with cudaMallocHost()
bool is_pinned_memory = false;

// if the external data is not on CPU, we use this temporary buffer
// to transfer data to CPU at some cases we need to visit the
// other devices' data
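
Combining the new flag with the ReallocFn change earlier in this diff, the sketch below shows the pattern the backends themselves use: set is_pinned_memory before the buffer is first allocated, then Resize(). In a -DWITH_GPU=ON build the CPU-side buffer then comes from cudaMallocHost via FDDeviceHostAllocator; without GPU support the flag triggers an FDASSERT. Shape, dtype, and tensor name are illustrative.

#include "fastdeploy/core/fd_tensor.h"

void PinnedOutputTensorExample() {
  fastdeploy::FDTensor tensor;
  tensor.is_pinned_memory = true;  // set before the first allocation
  // Resize() reallocates the buffer; with the flag set, ReallocFn routes this
  // CPU tensor through FDDeviceHostAllocator (cudaMallocHost).
  tensor.Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "output0");
  // A backend can now cudaMemcpyAsync device results into tensor.MutableData().
}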
3 changes: 3 additions & 0 deletions fastdeploy/pybind/runtime.cc
@@ -44,6 +44,8 @@ void BindRuntime(pybind11::module& m) {
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
.def("enable_pinned_memory", &RuntimeOption::EnablePinnedMemory)
.def("disable_pinned_memory", &RuntimeOption::DisablePinnedMemory)
.def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape)
.def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape)
.def_readwrite("model_file", &RuntimeOption::model_file)
@@ -200,6 +202,7 @@ void BindRuntime(pybind11::module& m) {
.def("numel", &FDTensor::Numel)
.def("nbytes", &FDTensor::Nbytes)
.def_readwrite("name", &FDTensor::name)
.def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory)
.def_readonly("shape", &FDTensor::shape)
.def_readonly("dtype", &FDTensor::dtype)
.def_readonly("device", &FDTensor::device);
7 changes: 7 additions & 0 deletions fastdeploy/runtime.cc
@@ -356,6 +356,10 @@ void RuntimeOption::EnableTrtFP16() { trt_enable_fp16 = true; }

void RuntimeOption::DisableTrtFP16() { trt_enable_fp16 = false; }

void RuntimeOption::EnablePinnedMemory() { enable_pinned_memory = true; }

void RuntimeOption::DisablePinnedMemory() { enable_pinned_memory = false; }

void RuntimeOption::SetTrtCacheFile(const std::string& cache_file_path) {
trt_serialize_file = cache_file_path;
}
@@ -503,6 +507,7 @@ void Runtime::CreatePaddleBackend() {
pd_option.gpu_id = option.device_id;
pd_option.delete_pass_names = option.pd_delete_pass_names;
pd_option.cpu_thread_num = option.cpu_thread_num;
pd_option.enable_pinned_memory = option.enable_pinned_memory;
#ifdef ENABLE_TRT_BACKEND
if (pd_option.use_gpu && option.pd_enable_trt) {
pd_option.enable_trt = true;
@@ -516,6 +521,7 @@
trt_option.min_shape = option.trt_min_shape;
trt_option.opt_shape = option.trt_opt_shape;
trt_option.serialize_file = option.trt_serialize_file;
trt_option.enable_pinned_memory = option.enable_pinned_memory;
pd_option.trt_option = trt_option;
}
#endif
@@ -606,6 +612,7 @@ void Runtime::CreateTrtBackend() {
trt_option.min_shape = option.trt_min_shape;
trt_option.opt_shape = option.trt_opt_shape;
trt_option.serialize_file = option.trt_serialize_file;
trt_option.enable_pinned_memory = option.enable_pinned_memory;

// TODO(jiangjiajun): inside usage, maybe remove this later
trt_option.remove_multiclass_nms_ = option.remove_multiclass_nms_;
11 changes: 11 additions & 0 deletions fastdeploy/runtime.h
@@ -204,6 +204,15 @@ struct FASTDEPLOY_DECL RuntimeOption {
*/
void SetTrtCacheFile(const std::string& cache_file_path);

/**
* @brief Enable pinned memory. Pinned memory can be utilized to speed up the data transfer between CPU and GPU. Currently it is only supported in the TRT backend and the Paddle Inference backend.
*/
void EnablePinnedMemory();

/**
* @brief Disable pinned memory
*/
void DisablePinnedMemory();

/**
* @brief Enable to collect shape in paddle trt backend
@@ -223,6 +232,8 @@

Device device = Device::CPU;

bool enable_pinned_memory = false;

// ======Only for ORT Backend========
// -1 means use default value by ort
// 0: ORT_DISABLE_ALL 1: ORT_ENABLE_BASIC 2: ORT_ENABLE_EXTENDED 3:
10 changes: 10 additions & 0 deletions python/fastdeploy/runtime.py
@@ -319,6 +319,16 @@ def disable_trt_fp16(self):
"""
return self._option.disable_trt_fp16()

def enable_pinned_memory(self):
"""Enable pinned memory. Pinned memory can be utilized to speedup the data transfer between CPU and GPU. Currently it's only suppurted in TRT backend and Paddle Inference backend.
"""
return self._option.enable_pinned_memory()

def disable_pinned_memory(self):
"""Disable pinned memory.
"""
return self._option.disable_pinned_memory()

def enable_paddle_to_trt(self):
"""While using TensorRT backend, enable_paddle_to_trt() will change to use Paddle Inference backend, and use its integrated TensorRT instead.
"""