Skip to content

Commit

Permalink
[benchmark] fixed paddlex benchmark for picodet 320 (PaddlePaddle#2046)
Browse files Browse the repository at this point in the history
  • Loading branch information
DefTruth authored Jun 20, 2023
1 parent 7191d2d commit 1144e0a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion benchmark/paddlex/benchmark_gpu_trt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fi
# PaddleDetection
./benchmark_ppdet --model PP-YOLOE+_crn_l_80e --image ppdet_det_img.jpg --config_path $CONFIG_PATH
./benchmark_ppdet --model rt_detr_hgnetv2_l --image ppdet_det_img.jpg --config_path $CONFIG_PATH
./benchmark_ppdet --model PP-PicoDet_s_320_lcnet --image ppdet_det_img.jpg --config_path $CONFIG_PATH
./benchmark_ppdet --model PP-PicoDet_s_320_lcnet --image ppdet_det_img.jpg --config_path $CONFIG_PATH --trt_shape 1,3,320,320:1,3,320,320:1,3,320,320
./benchmark_ppdet --model dino_r50_4scale --image ppdet_det_img.jpg --config_path $CONFIG_PATH

# PaddleSeg
Expand Down
12 changes: 10 additions & 2 deletions benchmark/paddlex/benchmark_ppdet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;

DEFINE_bool(no_nms, false, "Whether the model contains nms.");
DEFINE_string(trt_shape, "1,3,640,640:1,3,640,640:1,3,640,640",
"Set min/opt/max shape for trt/paddle_trt backend."
"eg:--trt_shape 1,3,640,640:1,3,640,640:1,3,640,640");
DEFINE_string(input_name, "image",
"Set input name for trt/paddle_trt backend."
"eg:--input_names x");

int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
Expand All @@ -46,8 +52,10 @@ int main(int argc, char* argv[]) {
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
option.trt_option.SetShape("image", {1, 3, 640, 640}, {1, 3, 640, 640},
{1, 3, 640, 640});
std::vector<std::vector<int32_t>> trt_shapes =
benchmark::ResultManager::GetInputShapes(FLAGS_trt_shape);
option.trt_option.SetShape(FLAGS_input_name, trt_shapes[0], trt_shapes[1],
trt_shapes[2]);
option.trt_option.SetShape("scale_factor", {1, 2}, {1, 2}, {1, 2});
option.trt_option.SetShape("im_shape", {1, 2}, {1, 2}, {1, 2});
}
Expand Down

0 comments on commit 1144e0a

Please sign in to comment.