From 38433b29005a251cece26afc16acb06b35ab9a58 Mon Sep 17 00:00:00 2001 From: zhoujun <572459439@qq.com> Date: Mon, 29 Jun 2020 16:16:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0mb3=E5=92=8Cresnet50=E9=A2=84?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/det_train_db_config.py | 8 ++++---- torchocr/networks/backbones/DetMobilenetV3.py | 7 ++++++- torchocr/networks/backbones/DetResNetvd.py | 10 ++++++++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/config/det_train_db_config.py b/config/det_train_db_config.py index c6639f5..d396902 100644 --- a/config/det_train_db_config.py +++ b/config/det_train_db_config.py @@ -30,7 +30,7 @@ 'resume_from': '', # 继续训练地址 'third_party_name': '', # 加载paddle模型可选 'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint", # 模型保存地址,log文件也保存在这里 - 'device': 'cuda:0', # 不建议修改 + 'device': 'cuda:0', # 不建议修改 'epochs': 1200, 'fine_tune_stage': ['backbone', 'neck', 'head'], 'print_interval': 1, # step为单位 @@ -48,7 +48,7 @@ config.model = { 'type': "DetModel", - 'backbone': {"type": "ResNet", 'layers': 18}, + 'backbone': {"type": "ResNet", 'layers': 18, 'pretrained': True}, # ResNet or MobileNetV3 'neck': {"type": 'FPN', 'out_channels': 256}, 'head': {"type": "DBHead"}, 'in_channels': 3, @@ -64,7 +64,7 @@ 'type': 'DBPostProcess', 'thresh': 0.3, # 二值化输出map的阈值 'box_thresh': 0.7, # 低于此阈值的box丢弃 - 'unclip_ratio': 1.5 # 扩大框的比例 + 'unclip_ratio': 1.5 # 扩大框的比例 } # for dataset @@ -111,7 +111,7 @@ }, 'loader': { 'type': 'DataLoader', - 'batch_size': 1, # 必须为1 + 'batch_size': 1, # 必须为1 'shuffle': False, 'num_workers': 1, 'collate_fn': { diff --git a/torchocr/networks/backbones/DetMobilenetV3.py b/torchocr/networks/backbones/DetMobilenetV3.py index 47ae69a..9964698 100644 --- a/torchocr/networks/backbones/DetMobilenetV3.py +++ b/torchocr/networks/backbones/DetMobilenetV3.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import torch from torch import nn from torchocr.networks.CommonModules import ConvBNACT, SEBlock @@ -46,7 +47,7 @@ def forward(self, x): class MobileNetV3(nn.Module): - def __init__(self, in_channels, **kwargs): + def __init__(self, in_channels, pretrained=True, **kwargs): """ the MobilenetV3 backbone network for detection module. Args: @@ -145,6 +146,10 @@ def __init__(self, in_channels, **kwargs): act='hard_swish') self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze)) + if pretrained: + ckpt_path = './weights/mb3_imagenet.pth' + self.load_state_dict(torch.load(ckpt_path)) + def make_divisible(self, v, divisor=8, min_value=None): if min_value is None: min_value = divisor diff --git a/torchocr/networks/backbones/DetResNetvd.py b/torchocr/networks/backbones/DetResNetvd.py index 8c9b18d..caf77d3 100644 --- a/torchocr/networks/backbones/DetResNetvd.py +++ b/torchocr/networks/backbones/DetResNetvd.py @@ -180,7 +180,7 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, in_channels, layers, **kwargs): + def __init__(self, in_channels, layers, pretrained=True, **kwargs): """ the Resnet backbone network for detection module. Args: @@ -230,6 +230,12 @@ def __init__(self, in_channels, layers, **kwargs): in_ch = block_list[-1].output_channels self.out_channels.append(in_ch) self.stages.append(nn.Sequential(*block_list)) + if pretrained: + if layers == 50: + ckpt_path = './weights/resnet50_vd_imagenet.pth' + self.load_state_dict(torch.load(ckpt_path)) + else: + print('pretrained weight only support resnet50 now') def load_3rd_state_dict(self, _3rd_name, _state): if _3rd_name == 'paddle': @@ -248,4 +254,4 @@ def forward(self, x): for stage in self.stages: x = stage(x) out.append(x) - return out \ No newline at end of file + return out