Skip to content

Commit

Permalink
新增arcface和mobilefacenet,但是没适配注意力机制
Browse files Browse the repository at this point in the history
  • Loading branch information
YOUYUANZY committed Jan 27, 2024
1 parent e83e074 commit 775912a
Show file tree
Hide file tree
Showing 11 changed files with 373 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .idea/TrainEnv.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/other.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ train:
cuda: true #GPU加速
fp16: true #是否使用混合精度(可减少显存,需pytorch1.7.1以上)
dataPath: train_data.txt #数据及标签路径
inputSize: [224,224,3] #输入图像大小
backbone: mobilenet #主干特征提取网络 mobilenet、
inputSize: [224,224,3] #输入图像大小 facenet[224,224,3]、arcface[112,112,3]
model: facenet #识别模型 facenet、arcface
backbone: mobilenet #主干特征提取网络 mobilenet(facenet)、mobilefacenet(arcface)
attention: AFNB #注意力机制 CBAM、APNB、AFNB、GCNet、SE、scSE、Triplet
onlyAttention: true #只训练注意力机制部分
onlyAttention: false #只训练注意力机制部分
weightPath: '' #模型权重路径
preTrained: false #是否预训练
batchSize: 30 #批次大小(faceNet需为3的倍数)
Expand Down
Binary file added lfwEvalInfo/arcface_mobilefacenet.pth
Binary file not shown.
70 changes: 70 additions & 0 deletions nets/arcface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torchsummary import summary

from nets.mobilefacenet import MobileFaceNet


class Arcface_Head(Module):
    """ArcFace additive-angular-margin classification head.

    Produces ``s * cos(theta + m)`` as the logit for each sample's target
    class and ``s * cos(theta)`` for every other class, where ``theta`` is
    the angle between the (already L2-normalized) embedding and the class
    weight vector.

    Args:
        embedding_size: dimensionality of the input embeddings.
        num_classes: number of identity classes.
        s: scale factor applied to the final logits.
        m: additive angular margin in radians.
    """

    def __init__(self, embedding_size=128, num_classes=10575, s=64., m=0.5):
        super(Arcface_Head, self).__init__()
        self.s = s
        self.m = m
        # One weight row per class; cosine similarity is the normalized dot product.
        self.weight = Parameter(torch.FloatTensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

        # Constants for the angle-addition formula cos(t + m) = cos t * cos m - sin t * sin m.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)      # past this cosine, theta + m would exceed pi
        self.mm = math.sin(math.pi - m) * m  # linear penalty used beyond the threshold

    def forward(self, input, label):
        """Return margin-adjusted, scaled logits of shape (batch, num_classes)."""
        cosine = F.linear(input, F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        # cos(theta + m); where theta + m would pass pi, fall back to cosine - mm
        # so the adjusted logit stays monotonic in theta.
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine.float() > self.th, phi.float(), cosine.float() - self.mm)

        # One-hot mask that selects each sample's ground-truth class column.
        one_hot = torch.zeros(cosine.size()).type_as(phi).long()
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin-adjusted logit on the target class, plain cosine elsewhere.
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return logits * self.s


class Arcface(nn.Module):
    """ArcFace recognition model: a backbone that emits embeddings plus an
    optional ArcFace classification head used only for training.

    Args:
        num_classes: number of identity classes (required when mode == "train").
        backbone: feature extractor name; only "mobilefacenet" is supported.
        pretrained: accepted for interface parity with Facenet; not used here.
        mode: "train" attaches an Arcface_Head; anything else is embedding-only.
    """

    def __init__(self, num_classes=None, backbone="mobilefacenet", pretrained=False, mode="train"):
        super(Arcface, self).__init__()
        if backbone != "mobilefacenet":
            raise ValueError('Unsupported backbone - `{}`.'.format(backbone))
        # mobilefacenet fixes the embedding width and logit scale.
        embedding_size = 128
        s = 32
        self.arcface = MobileFaceNet(embedding_size=embedding_size)

        self.mode = mode
        if mode == "train":
            self.head = Arcface_Head(embedding_size=embedding_size, num_classes=num_classes, s=s)

    def forward(self, x, y=None, mode="predict"):
        """Return L2-normalized embeddings, or class logits when mode != "predict"."""
        embedding = self.arcface(x)
        embedding = embedding.view(embedding.size()[0], -1)
        embedding = F.normalize(embedding)
        if mode == "predict":
            return embedding
        return self.head(embedding, y)


if __name__ == '__main__':
    # Quick sanity check: build the embedding-only model and print a summary.
    model = Arcface(mode='predict')
    # Fall back to CPU when no GPU is present: the original unconditionally
    # required CUDA and also called .cuda() redundantly after .to(device).
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    model = model.to(device)
    # torchsummary defaults to device="cuda"; pass the one actually in use.
    summary(model, (3, 112, 112), device='cuda' if use_cuda else 'cpu')
14 changes: 7 additions & 7 deletions nets/facenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def forward(self, x, mode="predict"):


if __name__ == '__main__':
a = Facenet(mode='predict', attention="Triplet")
for name, value in a.named_parameters():
print(name)
# device = torch.device('cuda:0')
# a = a.to(device)
# a.cuda()
# summary(a, (3, 224, 224))
a = Facenet(mode='predict', attention="AFNB")
# for name, value in a.named_parameters():
# print(name)
device = torch.device('cuda:0')
a = a.to(device)
a.cuda()
summary(a, (3, 112, 112))
138 changes: 138 additions & 0 deletions nets/mobilefacenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from torch import nn
from torch.nn import BatchNorm2d, Conv2d, Module, PReLU, Sequential


class Flatten(Module):
    """Collapse every dimension after the batch axis into a single one."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)


class Linear_block(Module):
    """Convolution followed by batch norm with no activation (a "linear" layer).

    Args:
        in_c / out_c: input / output channel counts.
        kernel, stride, padding, groups: forwarded to ``Conv2d``; bias is
            always disabled because the batch norm supplies the shift.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups,
                           stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        return self.bn(self.conv(x))


class Residual_Block(Module):
    """Inverted-residual bottleneck: 1x1 expand -> depthwise conv -> 1x1 project.

    ``groups`` doubles as the expanded (hidden) channel width. When
    ``residual`` is True the input is added back to the projection output,
    which requires in_c == out_c and a stride of 1.
    """

    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Residual_Block, self).__init__()
        # Pointwise expansion to `groups` channels.
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        # Depthwise spatial convolution (one filter per channel).
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        # Linear pointwise projection back down to `out_c` channels.
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        out = self.project(self.conv_dw(self.conv(x)))
        return x + out if self.residual else out


class Residual(Module):
    """A stack of ``num_block`` identity residual bottlenecks at fixed width ``c``."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        blocks = [
            Residual_Block(c, c, residual=True, kernel=kernel, padding=padding,
                           stride=stride, groups=groups)
            for _ in range(num_block)
        ]
        self.model = Sequential(*blocks)

    def forward(self, x):
        return self.model(x)


class Conv_block(Module):
    """Convolution -> batch norm -> PReLU, the standard activated conv unit.

    Args:
        in_c / out_c: input / output channel counts.
        kernel, stride, padding, groups: forwarded to ``Conv2d``; bias is
            disabled because the batch norm supplies the shift.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups,
                           stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        return self.prelu(self.bn(self.conv(x)))


class MobileFaceNet(Module):
    """MobileFaceNet backbone: maps a face image to an embedding.

    Input is expected to be (batch, 3, 112, 112) — the GDC depthwise conv
    uses a fixed 7x7 kernel, which matches the 7x7 feature map that a
    112x112 input produces after the four stride-2 stages. Output is
    (batch, embedding_size, 1, 1); callers flatten it.
    """

    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        # 112,112,3 -> 56,56,64
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))

        # 56,56,64 -> 56,56,64
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)

        # 56,56,64 -> 28,28,64  (stride-2 bottleneck, then 4 identity residuals)
        self.conv_23 = Residual_Block(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # 28,28,64 -> 14,14,128  (stride-2 bottleneck, then 6 identity residuals)
        self.conv_34 = Residual_Block(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # 14,14,128 -> 7,7,128  (stride-2 bottleneck, then 2 identity residuals)
        self.conv_45 = Residual_Block(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # Pointwise expansion 128 -> 512 before global depthwise pooling.
        self.sep = nn.Conv2d(128, 512, kernel_size=1, bias=False)
        self.sep_bn = nn.BatchNorm2d(512)
        self.prelu = nn.PReLU(512)

        # Global Depthwise Convolution: 7x7 depthwise collapses the 7x7 map to 1x1.
        self.GDC_dw = nn.Conv2d(512, 512, kernel_size=7, bias=False, groups=512)
        self.GDC_bn = nn.BatchNorm2d(512)

        # 1x1 projection to the final embedding width.
        self.features = nn.Conv2d(512, embedding_size, kernel_size=1, bias=False)
        self.last_bn = nn.BatchNorm2d(embedding_size)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-init convolutions/linears, unit-init batch norms.

        NOTE(review): nonlinearity='relu' is used even though the network's
        activations are PReLU — likely inherited from a reference
        implementation; confirm before changing.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Return the (batch, embedding_size, 1, 1) embedding for images x."""
        x = self.conv1(x)
        x = self.conv2_dw(x)
        x = self.conv_23(x)
        x = self.conv_3(x)
        x = self.conv_34(x)
        x = self.conv_4(x)
        x = self.conv_45(x)
        x = self.conv_5(x)

        x = self.sep(x)
        x = self.sep_bn(x)
        x = self.prelu(x)

        x = self.GDC_dw(x)
        x = self.GDC_bn(x)

        x = self.features(x)
        x = self.last_bn(x)
        return x
42 changes: 30 additions & 12 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim

from nets.arcface import Arcface
from nets.facenet import Facenet
from torch.utils.data import DataLoader

from utils.lossRecord import LossHistory
from utils.dataloader import FacenetDataset, LFWDataset, dataset_collate
from utils.dataloader import FacenetDataset, LFWDataset, face_dataset_collate, arcFaceDataset, arc_dataset_collate
from utils.training import get_Lr_Fun, set_lr, triplet_loss
from utils.utils import get_num_classes
from utils.utils import get_num_classes, seed_everything
from utils.epochTrain import epochTrain


def train(config, lfw):
seed_everything(11)
# 获取训练设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 一个标记(既负责提示信息又代表设备序号)
flag = 0
# 获取标签数量
num_classes = get_num_classes(config.dataPath)
# 加载模型
model = Facenet(backbone=config.backbone, attention=config.attention, num_classes=num_classes, pretrained=config.preTrained)
if config.model == 'facenet':
model = Facenet(backbone=config.backbone, attention=config.attention, num_classes=num_classes,
pretrained=config.preTrained)
elif config.model == 'arcface':
model = Arcface(num_classes=num_classes, backbone=config.backbone, pretrained=config.preTrained)
else:
raise ValueError('model unsupported')
# 加载权重
if config.weightPath != '':
if flag == 0:
Expand Down Expand Up @@ -103,25 +112,34 @@ def train(config, lfw):
if epoch_step == 0 or epoch_step_val == 0:
raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")
# 构建数据集加载器
train_dataset = FacenetDataset(config.inputSize, lines[:num_train], num_classes, random=True)
val_dataset = FacenetDataset(config.inputSize, lines[num_train:], num_classes, random=False)
if config.model == 'facenet':
train_dataset = FacenetDataset(config.inputSize, lines[:num_train], num_classes, random=True)
val_dataset = FacenetDataset(config.inputSize, lines[num_train:], num_classes, random=False)
elif config.model == 'arcface':
train_dataset = arcFaceDataset(config.inputSize, lines[:num_train], random=True)
val_dataset = arcFaceDataset(config.inputSize, lines[num_train:], random=False)
else:
raise ValueError('dataset unsupported')
# 获得训练和验证数据集
train_sampler = None
val_sampler = None
shuffle = True
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=config.batchSize // 3,
batchSize = config.batchSize // 3 if config.model == 'facenet' else config.batchSize
collate_fn = face_dataset_collate if config.model == 'facenet' else arc_dataset_collate
gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batchSize,
num_workers=config.numWorkers,
pin_memory=False,
drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=config.batchSize // 3,
pin_memory=True,
drop_last=True, collate_fn=collate_fn, sampler=train_sampler)
gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batchSize,
num_workers=config.numWorkers,
pin_memory=False,
drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)
pin_memory=True,
drop_last=True, collate_fn=collate_fn, sampler=val_sampler)
# 开始训练

for epoch in range(config.startEpoch, config.endEpoch):
set_lr(optimizer, lr_func, epoch)
epochTrain(model_train, model, loss_history, loss, optimizer, epoch, epoch_step, epoch_step_val, gen,
epochTrain(config.model, model_train, model, loss_history, loss, optimizer, epoch, epoch_step, epoch_step_val,
gen,
gen_val, config.endEpoch, config.cuda, LFW_loader, config.batchSize // 3, config.lfwEval,
config.fp16, scaler, config.savePeriod, config.saveDir, flag)
# 训练结束
Expand Down
Loading

0 comments on commit 775912a

Please sign in to comment.