Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#7266 from WenmuZhou/tipc
Browse files Browse the repository at this point in the history
add PP-OCRv2 det amp custom_black_list
  • Loading branch information
WenmuZhou authored Aug 19, 2022
2 parents 1411e80 + 0a247f0 commit 94710ae
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Global:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
use_amp: False
amp_level: O2
amp_custom_black_list: ['exp']

Architecture:
name: DistillationModel
Expand Down
3 changes: 2 additions & 1 deletion tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ def train(config,
model_average = True
# use amp
if scaler:
with paddle.amp.auto_cast(level=amp_level):
custom_black_list = config['Global'].get('amp_custom_black_list',[])
with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']:
Expand Down

0 comments on commit 94710ae

Please sign in to comment.