Skip to content

Commit

Permalink
添加mb3和resnet50预训练模型
Browse files Browse the repository at this point in the history
  • Loading branch information
WenmuZhou committed Jun 29, 2020
1 parent 45cb17e commit 38433b2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
8 changes: 4 additions & 4 deletions config/det_train_db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
'resume_from': '', # 继续训练地址
'third_party_name': '', # 加载paddle模型可选
'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint", # 模型保存地址,log文件也保存在这里
'device': 'cuda:0', # 不建议修改
'device': 'cuda:0', # 不建议修改
'epochs': 1200,
'fine_tune_stage': ['backbone', 'neck', 'head'],
'print_interval': 1, # step为单位
Expand All @@ -48,7 +48,7 @@

config.model = {
'type': "DetModel",
'backbone': {"type": "ResNet", 'layers': 18},
'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True}, # ResNet or MobileNetV3
'neck': {"type": 'FPN', 'out_channels': 256},
'head': {"type": "DBHead"},
'in_channels': 3,
Expand All @@ -64,7 +64,7 @@
'type': 'DBPostProcess',
'thresh': 0.3, # 二值化输出map的阈值
'box_thresh': 0.7, # 低于此阈值的box丢弃
'unclip_ratio': 1.5 # 扩大框的比例
'unclip_ratio': 1.5 # 扩大框的比例
}

# for dataset
Expand Down Expand Up @@ -111,7 +111,7 @@
},
'loader': {
'type': 'DataLoader',
'batch_size': 1, # 必须为1
'batch_size': 1, # 必须为1
'shuffle': False,
'num_workers': 1,
'collate_fn': {
Expand Down
7 changes: 6 additions & 1 deletion torchocr/networks/backbones/DetMobilenetV3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from torch import nn
from torchocr.networks.CommonModules import ConvBNACT, SEBlock

Expand Down Expand Up @@ -46,7 +47,7 @@ def forward(self, x):


class MobileNetV3(nn.Module):
def __init__(self, in_channels, **kwargs):
def __init__(self, in_channels, pretrained=True, **kwargs):
"""
the MobilenetV3 backbone network for detection module.
Args:
Expand Down Expand Up @@ -145,6 +146,10 @@ def __init__(self, in_channels, **kwargs):
act='hard_swish')
self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))

if pretrained:
ckpt_path = './weights/mb3_imagenet.pth'
self.load_state_dict(torch.load(ckpt_path))

def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
Expand Down
10 changes: 8 additions & 2 deletions torchocr/networks/backbones/DetResNetvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def forward(self, x):


class ResNet(nn.Module):
def __init__(self, in_channels, layers, **kwargs):
def __init__(self, in_channels, layers, pretrained=True, **kwargs):
"""
the Resnet backbone network for detection module.
Args:
Expand Down Expand Up @@ -230,6 +230,12 @@ def __init__(self, in_channels, layers, **kwargs):
in_ch = block_list[-1].output_channels
self.out_channels.append(in_ch)
self.stages.append(nn.Sequential(*block_list))
if pretrained:
if layers == 50:
ckpt_path = './weights/resnet50_vd_imagenet.pth'
self.load_state_dict(torch.load(ckpt_path))
else:
print('pretrained weight only support resnet50 now')

def load_3rd_state_dict(self, _3rd_name, _state):
if _3rd_name == 'paddle':
Expand All @@ -248,4 +254,4 @@ def forward(self, x):
for stage in self.stages:
x = stage(x)
out.append(x)
return out
return out

0 comments on commit 38433b2

Please sign in to comment.