Skip to content

Commit

Permalink
add aster rosetta
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Aug 30, 2023
1 parent 22b22a3 commit c6efc38
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 55 deletions.
2 changes: 1 addition & 1 deletion configs/rec/rec_mv3_none_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Optimizer:
weight_decay: 0.0

LRScheduler:
name: CosineAnnealingLR
name: ConstLR
warmup_epoch: 0

Architecture:
Expand Down
46 changes: 24 additions & 22 deletions configs/rec/rec_mv3_none_none_ctc.yml
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
Global:
use_gpu: True
device: gpu
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
output_dir: ./output/rec/mv3_none_none_ctc/
save_epoch_step: 3
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
output_dir: ./output/rec/mv3_none_none_ctc
eval_epoch_step: [0, 1]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt
use_tensorboard: false
infer_mode: false
infer_img: doc/imgs_words/en/word_1.png
character_dict_path: &character_dict_path
max_text_length: &max_text_length 25
use_space_char: &use_space_char False

Export:
export_dir:
export_shape: [ 1, 3, 32, 100 ]
dynamic_axes: [ 0, 2, 3 ]

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.0005
regularizer:
name: 'L2'
factor: 0
lr: 0.0005
weight_decay: 0

LRScheduler:
name: ConstLR
warmup_epoch: 0


Architecture:
model_type: rec
Expand All @@ -50,6 +50,8 @@ Loss:

PostProcess:
name: CTCLabelDecode
character_dict_path: *character_dict_path
use_space_char: *use_space_char

Metric:
name: RecMetric
Expand Down
45 changes: 23 additions & 22 deletions configs/rec/rec_mv3_tps_bilstm_att.yml
Original file line number Diff line number Diff line change
@@ -1,35 +1,34 @@
Global:
use_gpu: True
device: gpu
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
output_dir: ./output/rec/rec_mv3_tps_bilstm_att/
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
output_dir: ./output/rec/rec_mv3_tps_bilstm_att
eval_epoch_step: [0, 1]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt
use_tensorboard: false
infer_mode: false
infer_img: doc/imgs_words/en/word_1.png
character_dict_path: &character_dict_path
max_text_length: &max_text_length 25
use_space_char: &use_space_char False


Export:
export_dir:
export_shape: [ 1, 3, 32, 100 ]
dynamic_axes: [ 0, 2, 3 ]

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.0005
regularizer:
name: 'L2'
factor: 0.00001
lr: 0.0005
weight_decay: 0.00001

LRScheduler:
name: ConstLR
warmup_epoch: 0

Architecture:
model_type: rec
Expand Down Expand Up @@ -57,6 +56,8 @@ Loss:

PostProcess:
name: AttnLabelDecode
character_dict_path: *character_dict_path
use_space_char: *use_space_char

Metric:
name: RecMetric
Expand Down
4 changes: 2 additions & 2 deletions configs/rec/rec_mv3_tps_bilstm_ctc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Global:
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
output_dir: ./output/rec/mv3_tps_bilstm_ctc
output_dir: ./output/rec/rec_mv3_tps_bilstm_ctc
eval_epoch_step: [0, 1]
cal_metric_during_train: true
pretrained_model:
Expand All @@ -26,7 +26,7 @@ Optimizer:
weight_decay: 0.0

LRScheduler:
name: CosineAnnealingLR
name: ConstLR
warmup_epoch: 0

Architecture:
Expand Down
7 changes: 4 additions & 3 deletions ppocr/modeling/heads/rec_att_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot

def forward(self, inputs, targets=None, batch_max_length=25):
def forward(self, inputs, data=None, batch_max_length=25):
batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length

hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []

if targets is not None:
if data is not None:
targets = data[1]
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
Expand Down Expand Up @@ -76,7 +77,7 @@ def forward(self, inputs, targets=None, batch_max_length=25):
targets = next_input
if not self.training:
probs = paddle.nn.functional.softmax(probs, axis=2)
return probs
return {'res':probs}


class AttentionGRUCell(nn.Layer):
Expand Down
2 changes: 2 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@

| 模型 | 是否对齐 | 对齐误差| 配置文件 |
|---|---|---|---|
| rec_mv3_none_none_ctc | X | 2.114354e-09 | [config](configs/rec/rec_mv3_none_none_ctc.yml) |
| rec_mv3_none_bilstm_ctc | X | 1.1861777e-09 | [config](configs/rec/rec_mv3_none_bilstm_ctc.yml) |
| rec_mv3_tps_bilstm_ctc | X | 1.1886948e-09 | [config](configs/rec/rec_mv3_tps_bilstm_ctc.yml) |
| rec_mv3_tps_bilstm_att | X | 1.8528418e-09 | [config](configs/rec/rec_mv3_tps_bilstm_att.yml) |

## TODO

Expand Down
6 changes: 2 additions & 4 deletions torchocr/losses/rec_att_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ def __init__(self, **kwargs):
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')

def forward(self, predicts, batch):
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
1], predicts.shape[2]
predicts = predicts['res'][:, :-1]
targets = batch[1].long()[:,1:]
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"

Expand Down
3 changes: 2 additions & 1 deletion torchocr/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_multi_head import MultiHead
from .rec_sar_head import SARHead
from .rec_att_head import AttentionHead
# cls head
from .cls_head import ClsHead

support_dict = [
'MultiHead', 'SARHead', 'DBHead', 'CTCHead', 'ClsHead', 'PFHeadLocal'
'MultiHead', 'SARHead', 'DBHead', 'CTCHead', 'ClsHead', 'PFHeadLocal', 'AttentionHead'
]

module_name = config.pop('name')
Expand Down
Loading

0 comments on commit c6efc38

Please sign in to comment.