Fix directly visit member of FDTensor (PaddlePaddle#193)
* optimize tensorrt usage

* format code

* fix input shape error for onnx model

* Remove some code directly visit FDTensor member (PaddlePaddle#192)

remove some code directly visit FDTensor member

* fix directly visit member of FDTensor

Co-authored-by: root <[email protected]>
jiangjiajun and root authored Sep 6, 2022
1 parent 969531d commit e09ac18
Showing 8 changed files with 36 additions and 48 deletions.
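
The common thread across all eight files: call sites stop reaching into FDTensor's public members (shape, dtype, name, data) and go through its accessor API instead. A minimal before/after sketch of the pattern, assuming FDTensor::Allocate(shape, dtype, name) behaves as the call sites below suggest (sets the three fields and resizes the owned buffer):

    // Before: every caller mutated members by hand and had to keep them consistent.
    FDTensor t;
    t.shape = {1, 3, 224, 224};
    t.dtype = FDDataType::FP32;
    t.name = "image";
    t.data.resize(1 * 3 * 224 * 224 * sizeof(float));

    // After: one call does the bookkeeping; raw access goes through Data()/MutableData().
    FDTensor t2;
    t2.Allocate({1, 3, 224, 224}, FDDataType::FP32, "image");
    float* dst = reinterpret_cast<float*>(t2.MutableData());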
26 changes: 10 additions & 16 deletions csrc/fastdeploy/backends/ort/ort_backend.cc
@@ -164,32 +164,27 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
   return true;
 }
 
-void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor) {
+void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
   const auto info = value.GetTensorTypeAndShapeInfo();
   const auto data_type = info.GetElementType();
   size_t numel = info.GetElementCount();
-  tensor->shape = info.GetShape();
 
   if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
-    tensor->data.resize(numel * sizeof(float));
-    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
+    tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
+    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
            numel * sizeof(float));
-    tensor->dtype = FDDataType::FP32;
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
-    tensor->data.resize(numel * sizeof(int32_t));
-    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
+    tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
+    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
            numel * sizeof(int32_t));
-    tensor->dtype = FDDataType::INT32;
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
-    tensor->data.resize(numel * sizeof(int64_t));
-    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
+    tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
+    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
            numel * sizeof(int64_t));
-    tensor->dtype = FDDataType::INT64;
   } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
-    tensor->data.resize(numel * sizeof(double));
-    memcpy(static_cast<void*>(tensor->Data()), value.GetTensorData<void*>(),
+    tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
+    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
            numel * sizeof(double));
-    tensor->dtype = FDDataType::FP64;
   } else {
     FDASSERT(
         false,
@@ -231,8 +226,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
   std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
   outputs->resize(ort_outputs.size());
   for (size_t i = 0; i < ort_outputs.size(); ++i) {
-    (*outputs)[i].name = outputs_desc_[i].name;
-    CopyToCpu(ort_outputs[i], &((*outputs)[i]));
+    CopyToCpu(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name);
   }
 
   return true;
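
The refactor leans on FDTensor::Allocate, whose definition is not part of this diff. From the call sites above it plausibly reduces to something like the following sketch (an inference, not the actual implementation; Nbytes() is assumed to equal Numel() * FDDataTypeSize(dtype), which is how the pybind code later in this diff computes byte counts):

    // Hypothetical sketch of FDTensor::Allocate, inferred from its call sites.
    void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
                            const FDDataType& data_type,
                            const std::string& tensor_name) {
      shape = new_shape;      // replaces tensor->shape = info.GetShape();
      dtype = data_type;      // replaces tensor->dtype = FDDataType::...;
      name = tensor_name;     // replaces (*outputs)[i].name = ...;
      data.resize(Nbytes());  // replaces tensor->data.resize(numel * sizeof(T));
    }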
2 changes: 1 addition & 1 deletion csrc/fastdeploy/backends/ort/ort_backend.h
@@ -88,6 +88,6 @@ class OrtBackend : public BaseBackend {
   Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
 #endif
   OrtBackendOption option_;
-  void CopyToCpu(const Ort::Value& value, FDTensor* tensor);
+  void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
 };
 }  // namespace fastdeploy
8 changes: 2 additions & 6 deletions csrc/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -365,12 +365,8 @@ void TrtBackend::AllocateBufferInDynamicShape(
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
-    (*outputs)[ori_idx].dtype = GetFDDataType(outputs_desc_[i].dtype);
-    (*outputs)[ori_idx].shape.assign(output_dims.d,
-                                     output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].name = outputs_desc_[i].name;
-    (*outputs)[ori_idx].data.resize(Volume(output_dims) *
-                                    TrtDataTypeSize(outputs_desc_[i].dtype));
+    std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
     if ((*outputs)[ori_idx].Nbytes() >
         outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
       outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
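
For reference, the byte count the deleted lines computed by hand is the same quantity Nbytes() reports after Allocate(). A hedged sketch of that arithmetic (Volume and TrtDataTypeSize are assumed to behave as their names and usage here suggest):

    // Sketch: element count of a TensorRT shape, i.e. the product of its dims.
    int64_t Volume(const nvinfer1::Dims& dims) {
      int64_t v = 1;
      for (int i = 0; i < dims.nbDims; ++i) v *= dims.d[i];
      return v;
    }
    // Old code: Volume(output_dims) * TrtDataTypeSize(outputs_desc_[i].dtype)
    // New code: (*outputs)[ori_idx].Nbytes() -- the same value, owned by FDTensor.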
6 changes: 3 additions & 3 deletions csrc/fastdeploy/pybind/fastdeploy_runtime.cc
@@ -79,7 +79,7 @@ void BindRuntime(pybind11::module& m) {
             // TODO(jiangjiajun) Maybe skip memory copy is a better choice
             // use SetExternalData
             inputs[index].data.resize(iter->second.nbytes());
-            memcpy(inputs[index].data.data(), iter->second.mutable_data(),
+            memcpy(inputs[index].MutableData(), iter->second.mutable_data(),
                    iter->second.nbytes());
             inputs[index].name = iter->first;
             index += 1;
@@ -94,7 +94,7 @@ void BindRuntime(pybind11::module& m) {
             auto numpy_dtype = FDDataTypeToNumpyDataType(outputs[i].dtype);
             results.emplace_back(
                 pybind11::array(numpy_dtype, outputs[i].shape));
-            memcpy(results[i].mutable_data(), outputs[i].data.data(),
+            memcpy(results[i].mutable_data(), outputs[i].Data(),
                    outputs[i].Numel() * FDDataTypeSize(outputs[i].dtype));
           }
           return results;
@@ -134,4 +134,4 @@ void BindRuntime(pybind11::module& m) {
   m.def("get_available_backends", []() { return GetAvailableBackends(); });
 }
 
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace fastdeploy
2 changes: 1 addition & 1 deletion csrc/fastdeploy/pybind/main.cc.in
@@ -66,7 +66,7 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
     tensor->external_data_ptr = pyarray.mutable_data();
   } else {
     tensor->data.resize(pyarray.nbytes());
-    memcpy(tensor->data.data(), pyarray.mutable_data(), pyarray.nbytes());
+    memcpy(tensor->MutableData(), pyarray.mutable_data(), pyarray.nbytes());
   }
 }
 
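
Only the copying branch of PyArrayToTensor changes, but the surrounding if/else is worth seeing whole. A sketch under the assumption that the condition is a zero-copy flag (its real name is not visible in this hunk):

    // Sketch of both paths; `share_buffer` is a hypothetical name for the condition.
    void PyArrayToTensorSketch(pybind11::array& pyarray, FDTensor* tensor,
                               bool share_buffer) {
      if (share_buffer) {
        // Zero-copy: borrow numpy's buffer for the tensor's lifetime.
        tensor->external_data_ptr = pyarray.mutable_data();
      } else {
        // Owned copy: size the buffer, then copy through the accessor.
        tensor->data.resize(pyarray.nbytes());
        memcpy(tensor->MutableData(), pyarray.mutable_data(), pyarray.nbytes());
      }
    }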
15 changes: 7 additions & 8 deletions csrc/fastdeploy/pybind/main.h
@@ -42,8 +42,7 @@ pybind11::array TensorToPyArray(const FDTensor& tensor);
 cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
 #endif
 
-template <typename T>
-FDDataType CTypeToFDDataType() {
+template <typename T> FDDataType CTypeToFDDataType() {
   if (std::is_same<T, int32_t>::value) {
     return FDDataType::INT32;
   } else if (std::is_same<T, int64_t>::value) {
@@ -59,17 +58,17 @@ FDDataType CTypeToFDDataType() {
 }
 
 template <typename T>
-std::vector<pybind11::array> PyBackendInfer(
-    T& self, const std::vector<std::string>& names,
-    std::vector<pybind11::array>& data) {
+std::vector<pybind11::array>
+PyBackendInfer(T& self, const std::vector<std::string>& names,
+               std::vector<pybind11::array>& data) {
   std::vector<FDTensor> inputs(data.size());
   for (size_t i = 0; i < data.size(); ++i) {
     // TODO(jiangjiajun) here is considered to use user memory directly
     inputs[i].dtype = NumpyDataTypeToFDDataType(data[i].dtype());
     inputs[i].shape.insert(inputs[i].shape.begin(), data[i].shape(),
                            data[i].shape() + data[i].ndim());
     inputs[i].data.resize(data[i].nbytes());
-    memcpy(inputs[i].data.data(), data[i].mutable_data(), data[i].nbytes());
+    memcpy(inputs[i].MutableData(), data[i].mutable_data(), data[i].nbytes());
     inputs[i].name = names[i];
   }
 
@@ -81,10 +80,10 @@ std::vector<pybind11::array> PyBackendInfer(
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto numpy_dtype = FDDataTypeToNumpyDataType(outputs[i].dtype);
     results.emplace_back(pybind11::array(numpy_dtype, outputs[i].shape));
-    memcpy(results[i].mutable_data(), outputs[i].data.data(),
+    memcpy(results[i].mutable_data(), outputs[i].Data(),
            outputs[i].Numel() * FDDataTypeSize(outputs[i].dtype));
   }
   return results;
 }
 
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace fastdeploy
8 changes: 4 additions & 4 deletions csrc/fastdeploy/vision/classification/ppcls/model.cc
@@ -115,7 +115,7 @@ bool PaddleClasModel::Postprocess(const FDTensor& infer_result,
                                   ClassifyResult* result, int topk) {
   int num_classes = infer_result.shape[1];
   const float* infer_result_buffer =
-      reinterpret_cast<const float*>(infer_result.data.data());
+      reinterpret_cast<const float*>(infer_result.Data());
   topk = std::min(num_classes, topk);
   result->label_ids =
       utils::TopKIndices(infer_result_buffer, num_classes, topk);
@@ -150,6 +150,6 @@ bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
   return true;
 }
 
-}  // namespace classification
-}  // namespace vision
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace classification
+}  // namespace vision
+}  // namespace fastdeploy
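
Postprocess above hinges on utils::TopKIndices. Its implementation is not in this diff; the following self-contained sketch shows the operation it is assumed to perform (indices of the k highest scores), not the library's actual code:

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Hypothetical equivalent of utils::TopKIndices: indices of the k largest scores.
    std::vector<int32_t> TopKIndicesSketch(const float* scores, int n, int k) {
      std::vector<int32_t> idx(n);
      std::iota(idx.begin(), idx.end(), 0);  // 0, 1, ..., n-1
      std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                        [&](int32_t a, int32_t b) { return scores[a] > scores[b]; });
      idx.resize(k);  // keep only the top-k class ids
      return idx;
    }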
17 changes: 8 additions & 9 deletions csrc/fastdeploy/vision/utils/utils.h
@@ -14,11 +14,11 @@
 
 #pragma once
 
-#include <set>
-#include <vector>
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
 #include "fastdeploy/vision/common/result.h"
+#include <set>
+#include <vector>
 
 namespace fastdeploy {
 namespace vision {
@@ -87,8 +87,7 @@ void ArgmaxScoreMap(T infer_result_buffer, SegmentationResult* result,
   }
 }
 
-template <typename T>
-void NCHW2NHWC(FDTensor& infer_result) {
+template <typename T> void NCHW2NHWC(FDTensor& infer_result) {
   T* infer_result_buffer = reinterpret_cast<T*>(infer_result.MutableData());
   int num = infer_result.shape[0];
   int channel = infer_result.shape[1];
@@ -125,13 +124,13 @@ void SortDetectionResult(DetectionResult* output);
 void SortDetectionResult(FaceDetectionResult* result);
 
 // L2 Norm / cosine similarity (for face recognition, ...)
-FASTDEPLOY_DECL std::vector<float> L2Normalize(
-    const std::vector<float>& values);
+FASTDEPLOY_DECL std::vector<float>
+L2Normalize(const std::vector<float>& values);
 
 FASTDEPLOY_DECL float CosineSimilarity(const std::vector<float>& a,
                                        const std::vector<float>& b,
                                        bool normalized = true);
 
-}  // namespace utils
-}  // namespace vision
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace utils
+}  // namespace vision
+}  // namespace fastdeploy
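
The two declarations reflowed above implement standard formulas; their definitions live in the corresponding .cc file, not in this diff. A hedged sketch of what they are assumed to compute:

    #include <cmath>
    #include <vector>

    // Sketch: scale a vector to unit L2 norm.
    std::vector<float> L2NormalizeSketch(const std::vector<float>& v) {
      float norm = 0.f;
      for (float x : v) norm += x * x;
      norm = std::sqrt(norm);
      std::vector<float> out(v.size());
      for (size_t i = 0; i < v.size(); ++i) out[i] = v[i] / norm;
      return out;
    }

    // Sketch: cosine similarity; with normalized == true the dot product suffices.
    float CosineSimilaritySketch(const std::vector<float>& a,
                                 const std::vector<float>& b, bool normalized) {
      float dot = 0.f;
      for (size_t i = 0; i < a.size(); ++i) dot += a[i] * b[i];
      if (normalized) return dot;
      float na = 0.f, nb = 0.f;
      for (float x : a) na += x * x;
      for (float x : b) nb += x * x;
      return dot / (std::sqrt(na) * std::sqrt(nb));
    }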
