Skip to content

Commit

Permalink
add sar
Browse files: browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Aug 31, 2023
1 parent ad0e4d6 commit 32c1f09
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 45 deletions.
63 changes: 31 additions & 32 deletions configs/rec/rec_r31_sar.yml
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
Global:
use_gpu: true
device: gpu
epoch_num: 5
log_smooth_window: 20
print_batch_step: 20
output_dir: ./sar_rec
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
print_batch_step: 10
output_dir: ./output/rec/sar
eval_epoch_step: [0, 1]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
max_text_length: 30
infer_mode: False
use_space_char: False
rm_symbol: True
save_res_path: ./output/rec/predicts_sar.txt
checkpoints:
use_tensorboard: false
infer_mode: false
infer_img: doc/imgs_words/en/word_1.png
character_dict_path: &character_dict_path ppocr/utils/dict90.txt
max_text_length: &max_text_length 30
use_space_char: &use_space_char False
rm_symbol: &rm_symbol True

Export:
export_dir:
export_shape: [ 1, 3, 48, 160 ]
dynamic_axes: []

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4]
values: [0.001, 0.0001, 0.00001]
regularizer:
name: 'L2'
factor: 0
lr: 0.001
weight_decay: 0

LRScheduler:
name: MultiStepLR
milestones: [3,4]
warmup_epoch: 0

Architecture:
model_type: rec
Expand All @@ -47,17 +45,18 @@ Loss:

PostProcess:
name: SARLabelDecode
character_dict_path: *character_dict_path
use_space_char: *use_space_char
rm_symbol: *rm_symbol

Metric:
name: RecMetric


Train:
dataset:
name: SimpleDataSet
label_file_list: ['./train_data/train_list.txt']
data_dir: ./train_data/
ratio_list: 1.0
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
Expand All @@ -77,7 +76,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
Expand Down
12 changes: 6 additions & 6 deletions ppocr/modeling/heads/rec_sar_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,22 +389,22 @@ def __init__(self,
max_text_length=max_text_length,
pred_concat=pred_concat)

def forward(self, feat, targets=None):
def forward(self, feat, data=None):
'''
img_metas: [label, valid_ratio]
'''
holistic_feat = self.encoder(feat, targets) # bsz c
holistic_feat = self.encoder(feat, data) # bsz c
if self.training:
label = targets[0] # label
label = data[0] # label
final_out = self.decoder(
feat, holistic_feat, label, img_metas=targets)
feat, holistic_feat, label, img_metas=data)
else:
final_out = self.decoder(
feat,
holistic_feat,
label=None,
img_metas=targets,
img_metas=data,
train_mode=False)
# (bsz, seq_len, num_classes)

return final_out
return {'res':final_out}
1 change: 1 addition & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
| rec_r34_vd_tps_bilstm_ctc | N | 0.0035705192 | [config](configs/rec/rec_r34_vd_tps_bilstm_ctc.yml) |
| rec_mv3_tps_bilstm_att | Y | 1.8528418e-09 | [config](configs/rec/rec_mv3_tps_bilstm_att.yml) |
| rec_r34_vd_tps_bilstm_att | N | 0.0006942689 | [config](configs/rec/rec_r34_vd_tps_bilstm_att.yml) |
| rec_r31_sar | Y | 7.348353e-08 | [config](configs/rec/rec_r31_sar.yml) |

## TODO

Expand Down
2 changes: 1 addition & 1 deletion tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __call__(self, img_list):
norm_img_batch = norm_img_batch.copy()

preds = self.run(norm_img_batch)[0]
rec_result = self.postprocess_op(preds)
rec_result = self.postprocess_op({'res':preds})
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
Expand Down
4 changes: 2 additions & 2 deletions torchocr/modeling/heads/rec_multi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def forward(self, x, data=None):
if not self.training:
return {'res': ctc_out}
if self.gtc_head == 'sar':
sar_out = self.sar_head(x, data[1:])
sar_out = self.sar_head(x, data[1:])['res']
head_out['sar'] = sar_out
else:
gtc_out = self.gtc_head(self.before_gtc(x), data[1:])
gtc_out = self.gtc_head(self.before_gtc(x), data[1:])['res']
head_out['nrtr'] = gtc_out
return head_out
4 changes: 2 additions & 2 deletions torchocr/modeling/heads/rec_sar_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward(self, feat, img_metas=None):
if img_metas is not None and self.mask:
valid_ratios = img_metas[-1]

h_feat = feat.shape[2] # bsz c h w
h_feat = feat.size(2) # bsz c h w
feat_v = F.max_pool2d(
feat, kernel_size=(h_feat, 1), stride=1, padding=0)

Expand Down Expand Up @@ -378,4 +378,4 @@ def forward(self, feat, data=None):
train_mode=False)
# (bsz, seq_len, num_classes)

return final_out
return {'res': final_out}
4 changes: 2 additions & 2 deletions torchocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

def __call__(self, preds, batch=None, *args, **kwargs):
if 'res' in preds:
preds = preds['res']
preds = preds['res']
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
preds_idx = preds.argmax(axis=2)
Expand Down Expand Up @@ -536,6 +535,7 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
return result_list

def __call__(self, preds, batch=None, *args, **kwargs):
preds = preds['res']
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()
preds_idx = preds.argmax(axis=2)
Expand Down

0 comments on commit 32c1f09

Please sign in to comment.