Skip to content

Commit

Permalink
fix infer_rec for other input (PaddlePaddle#7286)
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 authored Aug 23, 2022
1 parent 1dbff73 commit 5a52480
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion doc/doc_ch/algorithm_rec_srn.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ python3 tools/export_model.py -c configs/rec/rec_r50_fpn_srn.yml -o Global.pretr
SRN文本识别模型推理,可以执行如下命令:

```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_char_type="ch" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_srn/" --rec_image_shape="1,64,256" --rec_algorithm="SRN" --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --use_space_char=False
```

<a name="4-2"></a>
Expand Down
13 changes: 7 additions & 6 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,13 @@ def __call__(self, img_list):
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
if self.rec_algorithm == "SRN":
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
if self.rec_algorithm == "SAR":
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
Expand All @@ -357,22 +364,16 @@ def __call__(self, img_list):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):

if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
valid_ratio = np.expand_dims(valid_ratio, axis=0)
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
Expand Down

0 comments on commit 5a52480

Please sign in to comment.