Skip to content

Commit

Permalink
更新Preprocess代码
Browse files Browse the repository at this point in the history
  • Loading branch information
Zheng-Bicheng committed Feb 15, 2023
1 parent c948b72 commit 4ccfbea
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
19 changes: 12 additions & 7 deletions fastdeploy/vision/keypointdet/pptinypose/pptinypose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ PPTinyPose::PPTinyPose(const std::string& model_file,
Backend::LITE};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
valid_kunlunxin_backends = {Backend::LITE};
valid_rknpu_backends = {Backend::RKNPU2};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
Expand Down Expand Up @@ -66,14 +67,18 @@ bool PPTinyPose::BuildPreprocessPipelineFromConfig() {
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = op["is_scale"].as<bool>();
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
if (!disable_normalize_) {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = op["is_scale"].as<bool>();
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
}
} else if (op_name == "Permute") {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
if (!disable_permute_) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
}
} else if (op_name == "TopDownEvalAffine") {
auto trainsize = op["trainsize"].as<std::vector<int>>();
int height = trainsize[1];
Expand Down
18 changes: 17 additions & 1 deletion fastdeploy/vision/keypointdet/pptinypose/pptinypose.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace keypointdetection {
*/
class FASTDEPLOY_DECL PPTinyPose : public FastDeployModel {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g pptinypose/model.pdmodel
* \param[in] params_file Path of parameter file, e.g pptinypose/model.pdiparams, if the model format is ONNX, this parameter will be ignored
Expand Down Expand Up @@ -68,6 +68,18 @@ class FASTDEPLOY_DECL PPTinyPose : public FastDeployModel {
*/
bool use_dark = true;

/// This function will disable normalize in preprocessing step.
void DisableNormalize() {
disable_normalize_ = true;
BuildPreprocessPipelineFromConfig();
}

/// This function will disable hwc2chw in preprocessing step.
void DisablePermute() {
disable_permute_ = true;
BuildPreprocessPipelineFromConfig();
}

protected:
bool Initialize();
/// Build the preprocess pipeline from the loaded model
Expand All @@ -84,6 +96,10 @@ class FASTDEPLOY_DECL PPTinyPose : public FastDeployModel {
private:
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
// for recording the switch of hwc2chw
bool disable_permute_ = false;
// for recording the switch of normalize
bool disable_normalize_ = false;
};
} // namespace keypointdetection
} // namespace vision
Expand Down

0 comments on commit 4ccfbea

Please sign in to comment.