Skip to content

Commit

Permalink
add table cpp infer
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Aug 10, 2022
1 parent 97f7f74 commit 3867c8c
Show file tree
Hide file tree
Showing 19 changed files with 940 additions and 71 deletions.
9 changes: 8 additions & 1 deletion deploy/cpp_infer/include/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ DECLARE_string(image_dir);
DECLARE_string(type);
// detection related
DECLARE_string(det_model_dir);
DECLARE_int32(max_side_len);
DECLARE_string(limit_type);
DECLARE_int32(limit_side_len);
DECLARE_double(det_db_thresh);
DECLARE_double(det_db_box_thresh);
DECLARE_double(det_db_unclip_ratio);
Expand All @@ -48,7 +49,13 @@ DECLARE_int32(rec_batch_num);
DECLARE_string(rec_char_dict_path);
DECLARE_int32(rec_img_h);
DECLARE_int32(rec_img_w);
// structure model related
DECLARE_string(table_model_dir);
DECLARE_int32(table_max_len);
DECLARE_int32(table_batch_num);
DECLARE_string(table_char_dict_path);
// forward related
DECLARE_bool(det);
DECLARE_bool(rec);
DECLARE_bool(cls);
DECLARE_bool(table);
12 changes: 7 additions & 5 deletions deploy/cpp_infer/include/ocr_det.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class DBDetector {
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const int &max_side_len,
const double &det_db_thresh,
const bool &use_mkldnn, const string &limit_type,
const int &limit_side_len, const double &det_db_thresh,
const double &det_db_box_thresh,
const double &det_db_unclip_ratio,
const std::string &det_db_score_mode,
Expand All @@ -54,7 +54,8 @@ class DBDetector {
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;

this->max_side_len_ = max_side_len;
this->limit_type_ = limit_type;
this->limit_side_len_ = limit_side_len;

this->det_db_thresh_ = det_db_thresh;
this->det_db_box_thresh_ = det_db_box_thresh;
Expand Down Expand Up @@ -84,7 +85,8 @@ class DBDetector {
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;

int max_side_len_ = 960;
string limit_type_ = "max";
int limit_side_len_ = 960;

double det_db_thresh_ = 0.3;
double det_db_box_thresh_ = 0.5;
Expand All @@ -106,7 +108,7 @@ class DBDetector {
Permute permute_op_;

// post-process
PostProcessor post_processor_;
DBPostProcessor post_processor_;
};

} // namespace PaddleOCR
11 changes: 6 additions & 5 deletions deploy/cpp_infer/include/paddleocr.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ class PPOCR {
ocr(std::vector<cv::String> cv_all_img_names, bool det = true,
bool rec = true, bool cls = true);

private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;

protected:
void det(cv::Mat img, std::vector<OCRPredictResult> &ocr_results,
std::vector<double> &times);
void rec(std::vector<cv::Mat> img_list,
Expand All @@ -62,6 +58,11 @@ class PPOCR {
std::vector<double> &times);
void log(std::vector<double> &det_times, std::vector<double> &rec_times,
std::vector<double> &cls_times, int img_num);

private:
DBDetector *detector_ = nullptr;
Classifier *classifier_ = nullptr;
CRNNRecognizer *recognizer_ = nullptr;
};

} // namespace PaddleOCR
79 changes: 79 additions & 0 deletions deploy/cpp_infer/include/paddlestructure.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/paddleocr.h>
#include <include/preprocess_op.h>
#include <include/structure_table.h>
#include <include/utility.h>

using namespace paddle_infer;

namespace PaddleOCR {

class PaddleStructure : public PPOCR {
public:
explicit PaddleStructure();
~PaddleStructure();
std::vector<std::vector<StructurePredictResult>>
structure(std::vector<cv::String> cv_all_img_names, bool layout = false,
bool table = true);

private:
StructureTableRecognizer *recognizer_ = nullptr;

void table(cv::Mat img, StructurePredictResult &structure_result,
std::vector<double> &time_info_table,
std::vector<double> &time_info_det,
std::vector<double> &time_info_rec,
std::vector<double> &time_info_cls);
std::string
rebuild_table(std::vector<std::string> rec_html_tags,
std::vector<std::vector<std::vector<int>>> rec_boxes,
std::vector<OCRPredictResult> &ocr_result);

float iou(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2);
float dis(std::vector<std::vector<int>> &box1,
std::vector<std::vector<int>> &box2);

static bool comparison_dis(const std::vector<float> &dis1,
const std::vector<float> &dis2) {
if (dis1[1] < dis2[1]) {
return true;
} else if (dis1[1] == dis2[1]) {
return dis1[0] < dis2[0];
} else {
return false;
}
}
};

} // namespace PaddleOCR
19 changes: 18 additions & 1 deletion deploy/cpp_infer/include/postprocess_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using namespace std;

namespace PaddleOCR {

class PostProcessor {
class DBPostProcessor {
public:
void GetContourArea(const std::vector<std::vector<float>> &box,
float unclip_ratio, float &distance);
Expand Down Expand Up @@ -90,4 +90,21 @@ class PostProcessor {
}
};

class TablePostProcessor {
public:
void init(std::string label_path);
void
Run(std::vector<float> &loc_preds, std::vector<float> &structure_probs,
std::vector<float> &rec_scores, std::vector<int> &loc_preds_shape,
std::vector<int> &structure_probs_shape,
std::vector<std::vector<std::string>> &rec_html_tag_batch,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes_batch,
std::vector<int> &width_list, std::vector<int> &height_list);

private:
std::vector<std::string> label_list_;
std::string end = "eos";
std::string beg = "sos";
};

} // namespace PaddleOCR
19 changes: 16 additions & 3 deletions deploy/cpp_infer/include/preprocess_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ class PermuteBatch {
public:
virtual void Run(const std::vector<cv::Mat> imgs, float *data);
};

class ResizeImgType0 {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, int max_size_len,
float &ratio_h, float &ratio_w, bool use_tensorrt);
virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type,
int limit_side_len, float &ratio_h, float &ratio_w,
bool use_tensorrt);
};

class CrnnResizeImg {
Expand All @@ -69,4 +70,16 @@ class ClsResizeImg {
const std::vector<int> &rec_image_shape = {3, 48, 192});
};

class TableResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};

class TablePadImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const int max_len = 488);
};

} // namespace PaddleOCR
100 changes: 100 additions & 0 deletions deploy/cpp_infer/include/structure_table.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>

using namespace paddle_infer;

namespace PaddleOCR {

class StructureTableRecognizer {
public:
explicit StructureTableRecognizer(
const std::string &model_dir, const bool &use_gpu, const int &gpu_id,
const int &gpu_mem, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path,
const bool &use_tensorrt, const std::string &precision,
const int &table_batch_num, const int &table_max_len) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->use_tensorrt_ = use_tensorrt;
this->precision_ = precision;
this->table_batch_num_ = table_batch_num;
this->table_max_len_ = table_max_len;

this->post_processor_.init(label_path);
LoadModel(model_dir);
}

// Load Paddle inference model
void LoadModel(const std::string &model_dir);

void Run(std::vector<cv::Mat> img_list,
std::vector<std::vector<std::string>> &rec_html_tags,
std::vector<float> &rec_scores,
std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
std::vector<double> &times);

private:
std::shared_ptr<Predictor> predictor_;

bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
int table_max_len_ = 488;

std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
bool is_scale_ = true;

bool use_tensorrt_ = false;
std::string precision_ = "fp32";
int table_batch_num_ = 1;

// pre-process
TableResizeImg resize_op_;
Normalize normalize_op_;
PermuteBatch permute_op_;
TablePadImg pad_op_;

// post-process
TablePostProcessor post_processor_;

}; // class StructureTableRecognizer

} // namespace PaddleOCR
24 changes: 24 additions & 0 deletions deploy/cpp_infer/include/utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ struct OCRPredictResult {
int cls_label = -1;
};

struct StructurePredictResult {
std::vector<int> box;
std::string type;
std::vector<OCRPredictResult> text_res;
std::string html;
float html_score = -1;
};

class Utility {
public:
static std::vector<std::string> ReadDict(const std::string &path);
Expand Down Expand Up @@ -68,6 +76,22 @@ class Utility {
static void CreateDir(const std::string &path);

static void print_result(const std::vector<OCRPredictResult> &ocr_result);

static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);

static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);

private:
static bool comparison_box(const OCRPredictResult &result1,
const OCRPredictResult &result2) {
if (result1.box[0][1] < result2.box[0][1]) {
return true;
} else if (result1.box[0][1] == result2.box[0][1]) {
return result1.box[0][0] < result2.box[0][0];
} else {
return false;
}
}
};

} // namespace PaddleOCR
Loading

0 comments on commit 3867c8c

Please sign in to comment.