Commit

swing
wangjun492 committed Nov 8, 2021
1 parent 34dec6d commit 2f97a0e
Showing 9 changed files with 1,040 additions and 18 deletions.
593 changes: 593 additions & 0 deletions backbone/Swin_Transformer.py

Large diffs are not rendered by default.

40 changes: 28 additions & 12 deletions backbone/backbone_def.py
@@ -19,6 +19,7 @@
from backbone.ReXNets import ReXNetV1
from backbone.LightCNN import LightCNN
from backbone.RepVGG import RepVGG
from backbone.Swin_Transformer import SwinTransformer

class BackboneFactory:
"""Factory to produce backbone according the backbone_conf.yaml.
@@ -30,7 +31,7 @@ class BackboneFactory:
    def __init__(self, backbone_type, backbone_conf_file):
        self.backbone_type = backbone_type
        with open(backbone_conf_file) as f:
-            backbone_conf = yaml.load(f)
+            backbone_conf = yaml.load(f, Loader=yaml.FullLoader)
        self.backbone_param = backbone_conf[backbone_type]
        print('backbone param:')
        print(self.backbone_param)
@@ -82,17 +83,6 @@ def get_backbone(self):
            backbone = ResidualAttentionNet(
                stage1_modules, stage2_modules, stage3_modules,
                feat_dim, out_h, out_w)
-        elif self.backbone_type == 'AttentionNet_wj':
-            stage1_modules = self.backbone_param['stage1_modules'] # the number of attention modules in stage1.
-            stage2_modules = self.backbone_param['stage2_modules'] # the number of attention modules in stage2.
-            stage3_modules = self.backbone_param['stage3_modules'] # the number of attention modules in stage3.
-            image_size = self.backbone_param['image_size'] # input image size, e.g. 112.
-            feat_dim = self.backbone_param['feat_dim'] # dimension of the output features, e.g. 512.
-            out_h = self.backbone_param['out_h'] # height of the feature map before the final features.
-            out_w = self.backbone_param['out_w'] # width of the feature map before the final features.
-            backbone = AttentionNet(
-                stage1_modules, stage2_modules, stage3_modules,
-                image_size, feat_dim, out_h, out_w)
        elif self.backbone_type == 'TF-NAS':
            drop_ratio = self.backbone_param['drop_ratio'] # drop out ratio.
            out_h = self.backbone_param['out_h'] # height of the feature map before the final features.
@@ -141,6 +131,32 @@ def get_backbone(self):
            backbone = RepVGG([blocks1, blocks2, blocks3, blocks4],
                              [width1, width2, width3, width4],
                              feat_dim, out_h, out_w)
        elif self.backbone_type == 'SwinTransformer':
            img_size = self.backbone_param['img_size']
            patch_size = self.backbone_param['patch_size']
            in_chans = self.backbone_param['in_chans']
            embed_dim = self.backbone_param['embed_dim']
            depths = self.backbone_param['depths']
            num_heads = self.backbone_param['num_heads']
            window_size = self.backbone_param['window_size']
            mlp_ratio = self.backbone_param['mlp_ratio']
            drop_rate = self.backbone_param['drop_rate']
            drop_path_rate = self.backbone_param['drop_path_rate']
            backbone = SwinTransformer(img_size=img_size,
                                       patch_size=patch_size,
                                       in_chans=in_chans,
                                       embed_dim=embed_dim,
                                       depths=depths,
                                       num_heads=num_heads,
                                       window_size=window_size,
                                       mlp_ratio=mlp_ratio,
                                       qkv_bias=True,
                                       qk_scale=None,
                                       drop_rate=drop_rate,
                                       drop_path_rate=drop_path_rate,
                                       ape=False,
                                       patch_norm=True,
                                       use_checkpoint=False)
        else:
            pass
        return backbone
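
A minimal usage sketch (not part of this diff), assuming the SwinTransformer entries added to the backbone_conf.yaml files later in this commit; the config path and batch size are illustrative only:

import torch
from backbone.backbone_def import BackboneFactory

# Build the Swin backbone from its YAML entry, the same way the factory builds the other backbones.
backbone_factory = BackboneFactory('SwinTransformer', 'training_mode/backbone_conf.yaml')
backbone = backbone_factory.get_backbone()

# The sample config sets img_size to 224, so feed 224x224 face crops.
images = torch.randn(4, 3, 224, 224)
features = backbone(images)
print(features.shape)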
2 changes: 1 addition & 1 deletion head/head_def.py
@@ -29,7 +29,7 @@ class HeadFactory:
    def __init__(self, head_type, head_conf_file):
        self.head_type = head_type
        with open(head_conf_file) as f:
-            head_conf = yaml.load(f)
+            head_conf = yaml.load(f, Loader=yaml.FullLoader)
        self.head_param = head_conf[head_type]
        print('head param:')
        print(self.head_param)
35 changes: 30 additions & 5 deletions test_protocol/backbone_conf.yaml
@@ -1,6 +1,5 @@
MobileFaceNet:
  feat_dim: 512
-  #out_h: 4
  out_h: 7
  out_w: 7

@@ -15,7 +14,7 @@ ResNet:
EfficientNet:
  width: 1.0
  depth: 1.0
-  image_size: 110
+  image_size: 112
  drop_ratio: 0.2
  out_h: 7
  out_w: 7
@@ -26,7 +25,7 @@ HRNet:
  out_h: 7
  out_w: 7
  feat_dim: 512
  IMAGE_SIZE:
  - 112
  - 112
  EXTRA:
@@ -88,8 +87,8 @@ GhostNet:

AttentionNet:
  stage1_modules: 1
-  stage2_modules: 2
-  stage3_modules: 3
+  stage2_modules: 1
+  stage3_modules: 1
  feat_dim: 512
  out_h: 7
  out_w: 7
@@ -119,6 +118,13 @@ ReXNet:
  feat_dim: 512
  dropout_ratio: 0.2

LightCNN:
  depth: 29
  out_h: 7
  out_w: 7
  feat_dim: 512
  dropout_ratio: 0.2

RepVGG:
  blocks1: 4
  blocks2: 6
@@ -131,3 +137,22 @@ RepVGG:
  out_h: 7
  out_w: 7
  feat_dim: 512
SwinTransformer:
  img_size: 224
  patch_size: 4
  in_chans: 3
  embed_dim: 96
  depths:
  - 2
  - 2
  - 18
  - 2
  num_heads:
  - 3
  - 6
  - 12
  - 24
  window_size: 7
  mlp_ratio: 4.0
  drop_rate: 0.0
  drop_path_rate: 0.3
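
A quick sanity check for this entry (illustrative, not part of the commit): load the YAML the same way BackboneFactory does and confirm the keys the factory reads are all present.

import yaml

with open('test_protocol/backbone_conf.yaml') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
swin_conf = conf['SwinTransformer']
expected = ['img_size', 'patch_size', 'in_chans', 'embed_dim', 'depths',
            'num_heads', 'window_size', 'mlp_ratio', 'drop_rate', 'drop_path_rate']
assert all(key in swin_conf for key in expected)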
19 changes: 19 additions & 0 deletions training_mode/backbone_conf.yaml
@@ -137,3 +137,22 @@ RepVGG:
  out_h: 7
  out_w: 7
  feat_dim: 512
SwinTransformer:
  img_size: 224
  patch_size: 4
  in_chans: 3
  embed_dim: 96
  depths:
  - 2
  - 2
  - 18
  - 2
  num_heads:
  - 3
  - 6
  - 12
  - 24
  window_size: 7
  mlp_ratio: 4.0
  drop_rate: 0.0
  drop_path_rate: 0.3
102 changes: 102 additions & 0 deletions training_mode/swin_training/lr_scheduler.py
@@ -0,0 +1,102 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler


def build_scheduler(optimizer, n_iter_per_epoch, epoches, warm_up_epoches):
    num_steps = int(epoches * n_iter_per_epoch)
    warmup_steps = int(warm_up_epoches * n_iter_per_epoch)
    lr_scheduler = None
    NAME = 'cosine'
    if NAME == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_steps,
            t_mul=1.,
            lr_min=5.0e-06,
            warmup_lr_init=5.0e-07,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
        )
    elif NAME == 'linear':
        lr_scheduler = LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,
            warmup_lr_init=5.0e-07,  # assumed: mirrors the cosine branch; the original read this from an undefined config object.
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif NAME == 'step':
        decay_steps = int(2 * n_iter_per_epoch)
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=0.1,  # assumed decay rate; the original read this from an undefined config object.
            warmup_lr_init=5.0e-07,  # assumed: mirrors the cosine branch.
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    return lr_scheduler


class LinearLRScheduler(Scheduler):
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
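
A usage sketch for build_scheduler (illustrative, not part of the commit); the toy model, iteration count, and learning rate are assumptions. Because the schedulers are built with t_in_epochs=False, they are stepped once per iteration via timm's step_update:

import torch
from lr_scheduler import build_scheduler  # i.e. training_mode/swin_training/lr_scheduler.py

model = torch.nn.Linear(512, 10)             # stand-in for the Swin backbone plus head.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)

n_iter_per_epoch = 1000                      # len(train_loader) in a real run.
epoches, warm_up_epoches = 20, 5
lr_scheduler = build_scheduler(optimizer, n_iter_per_epoch, epoches, warm_up_epoches)

for epoch in range(epoches):
    for idx in range(n_iter_per_epoch):
        # ... forward / backward / optimizer.step() ...
        lr_scheduler.step_update(epoch * n_iter_per_epoch + idx)  # per-iteration LR update (t_in_epochs=False).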
57 changes: 57 additions & 0 deletions training_mode/swin_training/optimizer.py
@@ -0,0 +1,57 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

from torch import optim as optim


def build_optimizer(model, lr, weight_decay=0.05):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = 'adamw'
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=0.9, nesterov=True,  # momentum assumed; the original read it from an undefined config object.
                              lr=lr, weight_decay=weight_decay)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=1.0e-08, betas=[0.9, 0.999],
                                lr=lr, weight_decay=weight_decay)

    return optimizer


def set_weight_decay(model, skip_list=(), skip_keywords=()):
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f"{name} has no weight decay")
        else:
            has_decay.append(param)
    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin
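
A quick sketch of what build_optimizer produces (illustrative; the toy model is an assumption): parameters are split into two AdamW param groups, with biases and 1-D normalization weights receiving zero weight decay.

import torch
from optimizer import build_optimizer  # i.e. training_mode/swin_training/optimizer.py

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.LayerNorm(512))
optimizer = build_optimizer(model, lr=5e-4, weight_decay=0.05)

# Two param groups: the second holds biases / norm weights with weight decay disabled.
print(len(optimizer.param_groups))                # 2
print(optimizer.param_groups[1]['weight_decay'])  # 0.0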
