Skip to content

Commit

Permalink
Merge pull request WenmuZhou#66 from WenmuZhou/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
WenmuZhou authored Jul 3, 2020
2 parents 33bf76c + 7acc4f7 commit 4629834
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ PytorchOCR开源的文本检测算法列表:

| 模型 | 骨干网络 | precision | recall | Hmean | 下载链接 |
| ---- | ---- | ---- | ---- | ---- | ---- |
|DB|ResNet18_vd|90.56%|72.66%|80.63%|见百度网盘|
|DB|MobileNetV3|84.63%|66.14%|74.23%|见百度网盘|

## 文本识别算法

Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pyclipper==1.1.0.post3
shapely==1.7.0
torch>=1.4.0
torchvision>=0.5.0
python-Levenshtein
lmdb
imgaug
python-Levenshtein>=0.12.0
lmdb>=0.98
imgaug>=0.4.0
3 changes: 3 additions & 0 deletions torchocr/deprecated/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Time : 2020/7/3 9:12
# @Author : zhoujun
19 changes: 16 additions & 3 deletions torchocr/networks/architectures/DetModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

from torchocr.networks.backbones.DetMobilenetV3 import MobileNetV3
from torchocr.networks.backbones.DetResNetvd import ResNet
from torchocr.networks.necks.FeaturePyramidNetwork import FeaturePyramidNetwork
from torchocr.networks.necks.FPN import FPN
from torchocr.networks.heads.DetDbHead import DBHead

backbone_dict = {'MobileNetV3': MobileNetV3, 'ResNet': ResNet}
neck_dict = {'FPN': FeaturePyramidNetwork}
neck_dict = {'FPN': FPN}
head_dict = {'DBHead': DBHead}


Expand Down Expand Up @@ -41,4 +41,17 @@ def forward(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
return x
return x


if __name__ == '__main__':
import torch

db_config = AttrDict(
in_channels=3,
backbone=AttrDict(type='MobileNetV3', layers=50, model_name='large',pretrained=True),
neck=AttrDict(type='FPN', out_channels=256),
head=AttrDict(type='DBHead')
)
x = torch.zeros(1, 3, 640, 640)
model = DetModel(db_config)
5 changes: 4 additions & 1 deletion torchocr/networks/backbones/DetMobilenetV3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import os
import torch
from torch import nn
Expand Down Expand Up @@ -149,10 +150,12 @@ def __init__(self, in_channels, pretrained=True, **kwargs):

if pretrained:
ckpt_path = f'./weights/MobileNetV3_{model_name}_x{str(scale).replace(".", "_")}.pth'
logger = logging.getLogger('torchocr')
if os.path.exists(ckpt_path):
logger.info('load imagenet weights')
self.load_state_dict(torch.load(ckpt_path))
else:
print(f'{ckpt_path} not exists')
logger.info(f'{ckpt_path} not exists')

def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
Expand Down
5 changes: 4 additions & 1 deletion torchocr/networks/backbones/DetResNetvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import division
from __future__ import print_function

import logging
from collections import OrderedDict
import os
import torch
Expand Down Expand Up @@ -233,10 +234,12 @@ def __init__(self, in_channels, layers, pretrained=True, **kwargs):
self.stages.append(nn.Sequential(*block_list))
if pretrained:
ckpt_path = f'./weights/resnet{layers}_vd.pth'
logger = logging.getLogger('torchocr')
if os.path.exists(ckpt_path):
logger.info('load imagenet weights')
self.load_state_dict(torch.load(ckpt_path))
else:
print(f'{ckpt_path} not exists')
logger.info(f'{ckpt_path} not exists')

def load_3rd_state_dict(self, _3rd_name, _state):
if _3rd_name == 'paddle':
Expand Down
51 changes: 51 additions & 0 deletions torchocr/networks/necks/FPN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# @Time : 2020/5/21 13:50
# @Author : zhoujun
import torch
from torch import nn
import torch.nn.functional as F


class FPN(nn.Module):
def __init__(self, in_channels, out_channels=256, **kwargs):
super().__init__()
assert len(in_channels) == 4
self.out_channels = out_channels
self.reduce_c5 = nn.Conv2d(in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, bias=False)
self.reduce_c4 = nn.Conv2d(in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, bias=False)
self.reduce_c3 = nn.Conv2d(in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, bias=False)
self.reduce_c2 = nn.Conv2d(in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, bias=False)

self.smooth_p5 = nn.Conv2d(in_channels=out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, bias=False)
self.smooth_p4 = nn.Conv2d(in_channels=out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, bias=False)
self.smooth_p3 = nn.Conv2d(in_channels=out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, bias=False)
self.smooth_p2 = nn.Conv2d(in_channels=out_channels, out_channels=self.out_channels // 4, kernel_size=3, padding=1, bias=False)

def forward(self, x):
c2, c3, c4, c5 = x
in5 = self.reduce_c5(c5)
in4 = self.reduce_c4(c4)
in3 = self.reduce_c3(c3)
in2 = self.reduce_c2(c2)

out4 = self._upsample_add(in5, in4)
out3 = self._upsample_add(out4, in3)
out2 = self._upsample_add(out3, in2)

p5 = self.smooth_p5(in5)
p4 = self.smooth_p4(out4)
p3 = self.smooth_p3(out3)
p2 = self.smooth_p2(out2)

out = self._upsample_cat(p2, p3, p4, p5)
return out

def _upsample_add(self, x, y):
return F.interpolate(x, size=y.size()[2:]) + y

def _upsample_cat(self, p2, p3, p4, p5):
h, w = p2.size()[2:]
p3 = F.interpolate(p3, size=(h, w))
p4 = F.interpolate(p4, size=(h, w))
p5 = F.interpolate(p5, size=(h, w))
return torch.cat([p2, p3, p4, p5], dim=1)

0 comments on commit 4629834

Please sign in to comment.