forked from YOUYUANZY/TrainEnv
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 11 changed files with 373 additions and 53 deletions.
Binary file not shown.
@@ -0,0 +1,70 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from torchsummary import summary

from nets.mobilefacenet import MobileFaceNet


class Arcface_Head(Module):
    """ArcFace head: adds an additive angular margin m to the target-class
    angle, then scales all logits by s."""

    def __init__(self, embedding_size=128, num_classes=10575, s=64., m=0.5):
        super(Arcface_Head, self).__init__()
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

        # Constants for cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m),
        # plus the threshold used when theta + m would exceed pi.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # Cosine similarity between the (already normalized) embeddings and
        # the L2-normalized class-center weights.
        cosine = F.linear(input, F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        # cos(theta + m) via the angle-addition identity.
        phi = cosine * self.cos_m - sine * self.sin_m
        # Where theta + m would pass pi, fall back to a monotonic linear penalty.
        phi = torch.where(cosine.float() > self.th, phi.float(), cosine.float() - self.mm)

        # Apply the margin only at each sample's target class.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

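# Worked example (illustrative numbers, not from the commit): with m = 0.5 and a
# target-class cosine of 0.8 (theta = arccos(0.8) ~ 0.6435 rad), the margin gives
# phi = cos(theta + m) = 0.8*cos(0.5) - 0.6*sin(0.5) ~ 0.414, so the target logit
# shrinks from 0.8*s to ~0.414*s while non-target logits keep cosine*s. This is
# what pushes embeddings of the same identity into a tighter angular cone.
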
class Arcface(nn.Module):
    def __init__(self, num_classes=None, backbone="mobilefacenet", pretrained=False, mode="train"):
        super(Arcface, self).__init__()
        if backbone == "mobilefacenet":
            embedding_size = 128
            s = 32
            self.arcface = MobileFaceNet(embedding_size=embedding_size)
        else:
            raise ValueError('Unsupported backbone - `{}`.'.format(backbone))

        self.mode = mode
        # The margin head is only needed for training; inference returns embeddings.
        if mode == "train":
            self.head = Arcface_Head(embedding_size=embedding_size, num_classes=num_classes, s=s)

    def forward(self, x, y=None, mode="predict"):
        x = self.arcface(x)
        x = x.view(x.size()[0], -1)
        # L2-normalize so the head's linear layer computes pure cosines.
        x = F.normalize(x)
        if mode == "predict":
            return x
        else:
            x = self.head(x, y)
            return x


if __name__ == '__main__':
    a = Arcface(mode='predict')
    # for name, value in a.named_parameters():
    #     print(name)
    device = torch.device('cuda:0')
    a = a.to(device)  # .to(device) already moves the model; the extra .cuda() was redundant
    summary(a, (3, 112, 112))
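For context, a minimal sketch of how this head would be driven during training (the batch tensors and the class count below are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

model = Arcface(num_classes=10575, mode="train")
images = torch.randn(8, 3, 112, 112)          # dummy batch of aligned 112x112 faces
labels = torch.randint(0, 10575, (8,))        # dummy identity labels

logits = model(images, labels, mode="train")  # (8, 10575) margin-adjusted, scaled logits
loss = F.cross_entropy(logits, labels)        # plain softmax cross-entropy on top
loss.backward()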
@@ -0,0 +1,138 @@
from torch import nn
from torch.nn import BatchNorm2d, Conv2d, Module, PReLU, Sequential


class Flatten(Module):
    # Not used below; the Arcface wrapper flattens with .view() instead.
    def forward(self, input):
        return input.view(input.size(0), -1)


class Linear_block(Module):
    """Conv + BatchNorm with no activation (linear bottleneck projection)."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups,
                           stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Residual_Block(Module):
    """Inverted-residual bottleneck: 1x1 expand -> depthwise conv -> 1x1 linear
    projection. Here `groups` doubles as the expanded channel width."""

    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Residual_Block, self).__init__()
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    """A stack of num_block bottlenecks with identity shortcuts (stride 1)."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(
                Residual_Block(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class Conv_block(Module):
    """Conv + BatchNorm + PReLU."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups,
                           stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        # 112,112,3 -> 56,56,64
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))

        # 56,56,64 -> 56,56,64 (depthwise)
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)

        # 56,56,64 -> 28,28,64
        self.conv_23 = Residual_Block(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # 28,28,64 -> 14,14,128
        self.conv_34 = Residual_Block(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # 14,14,128 -> 7,7,128
        self.conv_45 = Residual_Block(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))

        # 7,7,128 -> 7,7,512
        self.sep = nn.Conv2d(128, 512, kernel_size=1, bias=False)
        self.sep_bn = nn.BatchNorm2d(512)
        self.prelu = nn.PReLU(512)

        # Global depthwise conv: 7,7,512 -> 1,1,512
        self.GDC_dw = nn.Conv2d(512, 512, kernel_size=7, bias=False, groups=512)
        self.GDC_bn = nn.BatchNorm2d(512)

        # 1,1,512 -> 1,1,embedding_size
        self.features = nn.Conv2d(512, embedding_size, kernel_size=1, bias=False)
        self.last_bn = nn.BatchNorm2d(embedding_size)

        self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for conv/linear weights, unit-gain BatchNorm, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_dw(x)
        x = self.conv_23(x)
        x = self.conv_3(x)
        x = self.conv_34(x)
        x = self.conv_4(x)
        x = self.conv_45(x)
        x = self.conv_5(x)

        x = self.sep(x)
        x = self.sep_bn(x)
        x = self.prelu(x)

        x = self.GDC_dw(x)
        x = self.GDC_bn(x)

        x = self.features(x)
        x = self.last_bn(x)
        return x
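As a quick sanity check of the shape pipeline (a sketch, not part of the commit; the batch size and eval() usage are illustrative):

import torch

net = MobileFaceNet(embedding_size=128)
net.eval()  # use BatchNorm running stats so a tiny dummy batch behaves
with torch.no_grad():
    out = net(torch.randn(2, 3, 112, 112))
print(out.shape)  # torch.Size([2, 128, 1, 1]); the Arcface wrapper flattens this to (2, 128)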