Skip to content

Commit

Permalink
add nrtr
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Aug 31, 2023
1 parent 32c1f09 commit 92d3cc9
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 77 deletions.
51 changes: 25 additions & 26 deletions configs/rec/rec_mtb_nrtr.yml
Original file line number Diff line number Diff line change
@@ -1,37 +1,34 @@
Global:
use_gpu: True
device: gpu
epoch_num: 21
log_smooth_window: 20
print_batch_step: 10
output_dir: ./output/rec/nrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
output_dir: ./output/rec/nrtr
eval_epoch_step: [0, 1]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/EN_symbol_dict.txt
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_nrtr.txt
use_tensorboard: false
infer_mode: false
infer_img: doc/imgs_words/en/word_1.png
character_dict_path: &character_dict_path ppocr/utils/EN_symbol_dict.txt
max_text_length: &max_text_length 25
use_space_char: &use_space_char False

Export:
export_dir:
export_shape: [ 1, 1, 32, 100 ]
dynamic_axes: []

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.99
clip_norm: 5.0
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0.
lr: 0.0005
weight_decay: 0

LRScheduler:
name: CosineAnnealingLR
warmup_epoch: 2


Architecture:
model_type: rec
Expand All @@ -54,6 +51,8 @@ Loss:

PostProcess:
name: NRTRLabelDecode
character_dict_path: *character_dict_path
use_space_char: *use_space_char

Metric:
name: RecMetric
Expand Down Expand Up @@ -82,7 +81,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
Expand Down
13 changes: 7 additions & 6 deletions ppocr/modeling/heads/rec_nrtr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def forward_train(self, src, tgt):
logit = self.tgt_word_prj(output)
return logit

def forward(self, src, targets=None):
def forward(self, src, data=None):
"""Take in and process masked source/target sequences.
Args:
src: the sequence to the encoder (required).
Expand All @@ -141,14 +141,15 @@ def forward(self, src, targets=None):
"""

if self.training:
max_len = targets[1].max()
tgt = targets[0][:, :2 + max_len]
return self.forward_train(src, tgt)
max_len = data[1].max()
tgt = data[0][:, :2 + max_len]
res= self.forward_train(src, tgt)
else:
if self.beam_size > 0:
return self.forward_beam(src)
res= self.forward_beam(src)
else:
return self.forward_test(src)
res= self.forward_test(src)
return {'res':res}

def forward_test(self, src):

Expand Down
7 changes: 5 additions & 2 deletions tools/infer/predict_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ def __call__(self, img_list):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()

preds = self.run(norm_img_batch)[0]
cls_result = self.postprocess_op(preds)
preds = self.run(norm_img_batch)
if len(preds) == 1:
preds = preds[0]

cls_result = self.postprocess_op({'res': preds})
elapse += time.time() - tic

for rno in range(len(cls_result)):
Expand Down
27 changes: 5 additions & 22 deletions tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,12 @@ def __call__(self, img):
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()

outputs = self.run(img)

preds = {}
if self.det_algorithm == "EAST":
preds['f_geo'] = outputs[0]
preds['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
preds['f_border'] = outputs[0]
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['res'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
preds['level_{}'.format(i)] = output
elif self.det_algorithm == "CT":
preds['maps'] = outputs[0]
preds['score'] = outputs[1]
else:
raise NotImplementedError
preds = self.run(img)

if self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds = preds[0]

post_result = self.postprocess_op(preds, [-1, shape_list])
post_result = self.postprocess_op({'res':preds}, [-1, shape_list])
dt_boxes = post_result[0]['points']

if self.args.det_box_type == 'poly':
Expand Down
16 changes: 12 additions & 4 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from torchocr import Config
from torchocr.postprocess import build_post_process
from torchocr.data import create_operators, transform
from torchocr.utils.logging import get_logger
from torchocr.utils.utility import get_image_file_list, check_and_read
from tools.infer.onnx_engine import ONNXEngine
Expand All @@ -34,6 +35,7 @@ def __init__(self, args):
self.rec_algorithm = args.rec_algorithm

cfg = Config(config_path).cfg
self.ops = create_operators(cfg['Transforms'][1:])
self.postprocess_op = build_post_process(cfg['PostProcess'])

def resize_norm_img(self, img, max_wh_ratio):
Expand Down Expand Up @@ -77,15 +79,21 @@ def __call__(self, img_list):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
if self.rec_algorithm == 'nrtr':
norm_img = transform({'image':img_list[indices[ino]]}, self.ops)[0]
else:
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()

preds = self.run(norm_img_batch)[0]
rec_result = self.postprocess_op({'res':preds})
preds = self.run(norm_img_batch)

if len(preds) == 1:
preds = preds[0]

rec_result = self.postprocess_op({'res': preds})
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
Expand Down
9 changes: 3 additions & 6 deletions torchocr/losses/rec_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self,
self.with_all = with_all

def forward(self, pred, batch):

pred = pred['res']
if isinstance(pred, dict): # for ABINet
loss = {}
loss_sum = []
Expand Down Expand Up @@ -53,12 +53,9 @@ def forward(self, pred, batch):
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (
n_class - 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
non_pad_mask = torch.not_equal(
tgt, torch.zeros(
tgt.shape, dtype=tgt.dtype))
non_pad_mask = torch.not_equal(tgt, torch.zeros(tgt.shape, dtype=tgt.dtype, device=tgt.device))
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
Expand Down
3 changes: 2 additions & 1 deletion torchocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ def build_backbone(config, model_type):
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
from .rec_resnet_31 import ResNet31
from .rec_nrtr_mtb import MTB
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
support_dict = [
'MobileNetV1Enhance', 'ResNet31', 'MobileNetV3', 'PPLCNetV3', 'PPHGNet_small', 'ResNet'
'MobileNetV1Enhance', 'ResNet31', 'MobileNetV3', 'PPLCNetV3', 'PPHGNet_small', 'ResNet', 'MTB'
]
else:
raise NotImplementedError
Expand Down
33 changes: 33 additions & 0 deletions torchocr/modeling/backbones/rec_nrtr_mtb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from torch import nn


class MTB(nn.Module):
    """Small convolutional front-end used as the NRTR recognition backbone.

    When ``cnn_num`` is 2 the module stacks two ``Conv2d -> ReLU ->
    BatchNorm2d`` stages (3x3 kernels, stride 2, padding 1) producing 32 and
    then 64 channels. For any other ``cnn_num`` the internal block stays
    empty and the input is passed through untouched.

    Args:
        cnn_num: Number of convolution stages; only the value 2 builds layers.
        in_channels: Channel count of the input images.

    Attributes:
        out_channels: Set to ``in_channels``; it is not updated to reflect
            the width produced by the convolution stack.
    """

    def __init__(self, cnn_num, in_channels):
        super().__init__()
        self.block = nn.Sequential()
        self.out_channels = in_channels
        self.cnn_num = cnn_num

        if cnn_num != 2:
            # Nothing to build: the empty Sequential acts as an identity.
            return

        stage_in = in_channels
        for stage in range(cnn_num):
            # Stage widths double each time: 32, 64.
            stage_out = 32 * (2 ** stage)
            conv = nn.Conv2d(
                in_channels=stage_in,
                out_channels=stage_out,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.block.add_module('conv_{}'.format(stage), conv)
            self.block.add_module('relu_{}'.format(stage), nn.ReLU())
            self.block.add_module(
                'bn_{}'.format(stage), nn.BatchNorm2d(stage_out))
            stage_in = stage_out

    def forward(self, images):
        """Run the conv stack and flatten the result into a sequence.

        Args:
            images: Input tensor; assumed to be a 4-D (batch, channel,
                height, width) batch -- TODO confirm against the data loader.

        Returns:
            For ``cnn_num == 2``, a 3-D tensor obtained by swapping axes 1
            and 3 of the feature map and merging the last two axes. Otherwise
            the output of the (empty) block, i.e. the input itself.
        """
        features = self.block(images)
        if self.cnn_num != 2:
            return features

        # Swap axes 1 and 3 -> (b, w, h, c), then merge h and c per position.
        swapped = features.permute(0, 3, 2, 1)
        batch, width, height, channels = swapped.shape
        return torch.reshape(swapped, (batch, width, height * channels))
3 changes: 2 additions & 1 deletion torchocr/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ def build_head(config):
from .rec_multi_head import MultiHead
from .rec_sar_head import SARHead
from .rec_att_head import AttentionHead
from .rec_nrtr_head import Transformer
# cls head
from .cls_head import ClsHead

support_dict = [
'MultiHead', 'SARHead', 'DBHead', 'CTCHead', 'ClsHead', 'PFHeadLocal', 'AttentionHead'
'MultiHead', 'SARHead', 'DBHead', 'CTCHead', 'ClsHead', 'PFHeadLocal', 'AttentionHead', 'Transformer'
]

module_name = config.pop('name')
Expand Down
17 changes: 9 additions & 8 deletions torchocr/modeling/heads/rec_nrtr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def forward_train(self, src, tgt):
logit = self.tgt_word_prj(output)
return logit

def forward(self, src, targets=None):
def forward(self, src, data=None):
"""Take in and process masked source/target sequences.
Args:
src: the sequence to the encoder (required).
Expand All @@ -119,14 +119,15 @@ def forward(self, src, targets=None):
"""

if self.training:
max_len = targets[1].max()
tgt = targets[0][:, :2 + max_len]
return self.forward_train(src, tgt)
max_len = data[1].max()
tgt = data[0][:, :2 + max_len]
res = self.forward_train(src, tgt)
else:
if self.beam_size > 0:
return self.forward_beam(src)
res = self.forward_beam(src)
else:
return self.forward_test(src)
res = self.forward_test(src)
return {'res': res}

def forward_test(self, src):

Expand All @@ -138,8 +139,8 @@ def forward_test(self, src):
memory = src # B N C
else:
memory = src
dec_seq = torch.full((bs, 1), 2, dtype=torch.int64)
dec_prob = torch.full((bs, 1), 1., dtype=torch.float32)
dec_seq = torch.full((bs, 1), 2, dtype=torch.int64, device=src.device)
dec_prob = torch.full((bs, 1), 1., dtype=torch.float32, device=src.device)
for len_dec_seq in range(1, self.max_len):
dec_seq_embed = self.embedding(dec_seq)
dec_seq_embed = self.positional_encoding(dec_seq_embed)
Expand Down
2 changes: 1 addition & 1 deletion torchocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
use_space_char)

def __call__(self, preds, batch=None, *args, **kwargs):

preds = preds['res']
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
Expand Down

0 comments on commit 92d3cc9

Please sign in to comment.