Skip to content

Commit

Permalink
新增transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
Bourne-M committed Jan 28, 2022
1 parent 6114dd6 commit 10ac864
Show file tree
Hide file tree
Showing 5 changed files with 668 additions and 18 deletions.
11 changes: 10 additions & 1 deletion config/cfg_det_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,20 @@
'lr': 0.001,
'weight_decay': 1e-4,
}
# Use this optimizer when the backbone is set to swin_transformer
# config.optimizer = {
# 'type': 'AdamW',
# 'lr': 0.0001,
# 'betas': (0.9, 0.999),
# 'weight_decay': 0.05,
# }


config.model = {
# The backbone accepts 'pretrained': False/True
'type': "DetModel",
'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True}, # ResNet or MobileNetV3
'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True}, # ResNet or MobileNetV3
# 'backbone': {"type": "Transformer", 'pretrained': True},#swin_transformer
'neck': {"type": 'DB_fpn', 'out_channels': 256},
'head': {"type": "DBHead"},
'in_channels': 3,
Expand Down
20 changes: 7 additions & 13 deletions torchocr/datasets/DetDataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class JsonDataset(Dataset):
"""
from https://github.com/WenmuZhou/OCR_DataSet/blob/master/dataset/det.py
"""

def __init__(self, config):
assert config.img_mode in ['RGB', 'BRG', 'GRAY']
self.ignore_tags = config.ignore_tags
Expand Down Expand Up @@ -72,18 +73,10 @@ def load_data(self, path: str) -> list:
for annotation in gt['annotations']:
if len(annotation['polygon']) == 0 or len(annotation['text']) == 0:
continue
if len(annotation['polygon']) != 4:
a = np.array(annotation['polygon'], dtype=np.int32)
x, y, w, h = cv2.boundingRect(a)
polygons.append([[x, y], [x + w, y], [x + w, y + h], [x, y + h]])
texts.append('ignore')
illegibility_list.append(True)
language_list.append(annotation['language'])
else:
polygons.append(annotation['polygon'])
texts.append(annotation['text'])
illegibility_list.append(annotation['illegibility'])
language_list.append(annotation['language'])
polygons.append(annotation['polygon'])
texts.append(annotation['text'])
illegibility_list.append(annotation['illegibility'])
language_list.append(annotation['language'])
if self.load_char_annotation:
for char_annotation in annotation['chars']:
if len(char_annotation['polygon']) == 0 or len(char_annotation['char']) == 0:
Expand All @@ -92,7 +85,7 @@ def load_data(self, path: str) -> list:
texts.append(char_annotation['char'])
illegibility_list.append(char_annotation['illegibility'])
language_list.append(char_annotation['language'])
data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': np.array(polygons),
data_list.append({'img_path': img_path, 'img_name': gt['img_name'], 'text_polys': polygons,
'texts': texts, 'ignore_tags': illegibility_list})
except:
print(f'error gt:{img_path}')
Expand Down Expand Up @@ -142,6 +135,7 @@ def __len__(self):
from torchocr.utils import show_img, draw_bbox

from matplotlib import pyplot as plt

dataset = JsonDataset(config.dataset.train.dataset)
train_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=0)
for i, data in enumerate(tqdm(train_loader)):
Expand Down
6 changes: 3 additions & 3 deletions torchocr/datasets/det_modules/iaa_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def may_augment_annotation(self, aug, data, shape):
line_polys = []
for poly in data['text_polys']:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data['text_polys'] = np.array(line_polys)
line_polys.append(np.array(new_poly))
data['text_polys'] = line_polys
return data

def may_augment_poly(self, aug, img_shape, poly):
    """Run the augmenter ``aug`` over the vertices of a single polygon.

    Every point of ``poly`` is wrapped as an imgaug keypoint, the keypoints
    are augmented together against ``img_shape``, and the resulting
    coordinates are returned as a list of ``(x, y)`` tuples.
    """
    source_points = []
    for vertex in poly:
        source_points.append(imgaug.Keypoint(vertex[0], vertex[1]))
    on_image = imgaug.KeypointsOnImage(source_points, shape=img_shape)
    # augment_keypoints is given a one-element list, so take the first result.
    augmented = aug.augment_keypoints([on_image])[0]
    return [(kp.x, kp.y) for kp in augmented.keypoints]
3 changes: 2 additions & 1 deletion torchocr/networks/architectures/DetModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from torchocr.networks.heads.DetDbHead import DBHead
from torchocr.networks.heads.DetPseHead import PseHead
from torchocr.networks.backbones.DetGhostNet import GhostNet
from torchocr.networks.backbones.Transformer import *

backbone_dict = {'MobileNetV3': MobileNetV3, 'ResNet': ResNet,'GhostNet':GhostNet}
backbone_dict = {'MobileNetV3': MobileNetV3, 'ResNet': ResNet, 'GhostNet': GhostNet, 'Transformer': SwinTransformer}
neck_dict = {'DB_fpn': DB_fpn, 'pse_fpn': PSEFpn}
head_dict = {'DBHead': DBHead, 'PseHead': PseHead}

Expand Down
Loading

0 comments on commit 10ac864

Please sign in to comment.