forked from modelscope/facechain
Commit
* Add files via upload (repeated ×14)
* Delete face_module/TransFace/README.md
1 parent 3ad2494 · commit 0ce8353 · 42 changed files with 5,520 additions and 1 deletion
@@ -0,0 +1,38 @@
import cv2
import numpy as np
from math import sqrt


def amplitude_spectrum_mix(img1, img2, alpha, ratio=1.0):  # img1: source, img2: random
    """Mix the amplitude spectra of two images in the Fourier domain,
    keeping the phase of img1. Input image size: ndarray of [H, W, C]."""
    lam = np.random.uniform(0, alpha)
    assert img1.shape == img2.shape
    h, w, c = img1.shape
    # Only a central crop (side scaled by sqrt(ratio)) of the shifted spectrum is mixed.
    h_crop = int(h * sqrt(ratio))
    w_crop = int(w * sqrt(ratio))
    h_start = h // 2 - h_crop // 2
    w_start = w // 2 - w_crop // 2

    img1_fft = np.fft.fft2(img1, axes=(0, 1))
    img2_fft = np.fft.fft2(img2, axes=(0, 1))
    img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft)
    img2_abs, img2_pha = np.abs(img2_fft), np.angle(img2_fft)

    # Shift the zero-frequency component to the centre before cropping.
    img1_abs = np.fft.fftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.fftshift(img2_abs, axes=(0, 1))

    img1_abs_ = np.copy(img1_abs)
    img2_abs_ = np.copy(img2_abs)

    # Linearly interpolate the two amplitude spectra in the central region.
    img1_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \
        lam * img2_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + \
        (1 - lam) * img1_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop]

    img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.ifftshift(img2_abs, axes=(0, 1))

    # Recombine the mixed amplitude with the original phase of img1.
    img_src_random = img1_abs * (np.e ** (1j * img1_pha))
    img_src_random = np.real(np.fft.ifft2(img_src_random, axes=(0, 1)))
    img_src_random = np.uint8(np.clip(img_src_random, 0, 255))
    return img_src_random
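For context, a minimal usage sketch of this augmentation; the image paths, the resize, and the alpha value are illustrative assumptions, not part of the commit:

```python
import cv2

# Hypothetical inputs: a source face crop and a randomly chosen style image.
img_src = cv2.imread("source_face.jpg")   # ndarray [H, W, C], uint8
img_rand = cv2.imread("random_image.jpg")
img_rand = cv2.resize(img_rand, (img_src.shape[1], img_src.shape[0]))

# lam is drawn from U(0, alpha); ratio=1.0 mixes the full amplitude spectrum.
augmented = amplitude_spectrum_mix(img_src, img_rand, alpha=0.5, ratio=1.0)
cv2.imwrite("augmented_face.jpg", augmented)
```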
@@ -1 +1,72 @@
- ### The code will be released in the future.
# TransFace: Calibrating Transformer Training for Face Recognition from a Data-Centric Perspective (ICCV 2023)
This is the official PyTorch implementation of [TransFace](https://openaccess.thecvf.com/content/ICCV2023/html/Dan_TransFace_Calibrating_Transformer_Training_for_Face_Recognition_from_a_Data-Centric_ICCV_2023_paper.html).

[arXiv](https://arxiv.org/abs/2308.10133)
## ModelScope
You can quickly try out and invoke our TransFace model on [ModelScope](https://modelscope.cn/models/damo/cv_vit_face-recognition/summary), as sketched below.
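A minimal calling sketch, assuming the standard ModelScope pipeline interface (`Tasks.face_recognition`, `OutputKeys.IMG_EMBEDDING`) and a placeholder image path:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.outputs import OutputKeys

# Build a face-recognition pipeline backed by the TransFace model card.
face_recognition = pipeline(Tasks.face_recognition, model='damo/cv_vit_face-recognition')

# 'face1.jpg' is a placeholder for an aligned face image.
emb = face_recognition('face1.jpg')[OutputKeys.IMG_EMBEDDING]
print(emb.shape)  # one face embedding vector
```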
## Requirements
* Install PyTorch (torch>=1.9.0)
* ```pip install -r requirement.txt```
## Datasets
You can download the training datasets, including MS1MV2 and Glint360K:
* MS1MV2: [Google Drive](https://drive.google.com/file/d/1SXS4-Am3bsKSK615qbYdbA_FMVh3sAvR/view)
* Glint360K: [Baidu](https://pan.baidu.com/share/init?surl=GsYqTTt7_Dn8BfxxsLFN0w) (code: o3az)

You can download the test dataset IJB-C as follows:
* IJB-C: [Google Drive](https://drive.google.com/file/d/1aC4zf2Bn0xCVH_ZtEuQipR2JvRb1bf8o/view)
## How to Train Models
1. You need to modify the path of the training data in every configuration file in the `configs` folder (a hedged illustration follows this section).

2. To run on a machine with 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py
```
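For step 1, the kind of edit involved might look like the following; the file name and key names follow the InsightFace arcface_torch config convention and are assumptions here, not taken from this commit:

```python
# Hypothetical excerpt from a config file such as configs/ms1mv2_vit_s.py.
from easydict import EasyDict as edict

config = edict()
config.rec = "/path/to/ms1mv2"  # <- point this at the folder holding train.rec / train.idx
config.num_classes = 85742      # identities in MS1MV2
config.num_image = 5822653      # images in MS1MV2
```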
## How to Test Models
1. You need to modify the path of the IJB-C dataset in eval_ijbc.py.

2. Run:
```
python eval_ijbc.py --model-prefix work_dirs/glint360k_vit_s/model.pt --result-dir work_dirs/glint360k_vit_s --network vit_s_dp005_mask_0 > ijbc_glint360k_vit_s.log 2>&1 &
```
## TransFace Pretrained Models

You can download the TransFace models reported in our paper as follows:

| Training Data | Model | IJB-C(1e-6) | IJB-C(1e-5) | IJB-C(1e-4) | IJB-C(1e-3) | IJB-C(1e-2) | IJB-C(1e-1) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| MS1MV2 | [TransFace-S](https://drive.google.com/file/d/1UZWCg7jNESDv8EWs7mxQSswCMGbAZNF4/view?usp=share_link) | 86.75 | 93.87 | 96.45 | 97.51 | 98.34 | 98.99 |
| MS1MV2 | [TransFace-B](https://drive.google.com/file/d/16O-q30mH8d3lECqa5eJd8rABaUlNhQ0K/view?usp=share_link) | 86.73 | 94.15 | 96.55 | 97.73 | 98.47 | 99.11 |
| MS1MV2 | [TransFace-L](https://drive.google.com/file/d/1uXUFT6ujEPqvCTHzONsp6-DMIc24Cc85/view?usp=share_link) | 86.90 | 94.55 | 96.59 | 97.80 | 98.45 | 99.04 |

| Training Data | Model | IJB-C(1e-6) | IJB-C(1e-5) | IJB-C(1e-4) | IJB-C(1e-3) | IJB-C(1e-2) | IJB-C(1e-1) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| Glint360K | [TransFace-S](https://drive.google.com/file/d/18Zh_zMlYttKVIGArmDYNEchIvUSH5FQ1/view?usp=share_link) | 89.93 | 96.06 | 97.33 | 98.00 | 98.49 | 99.11 |
| Glint360K | [TransFace-B](https://drive.google.com/file/d/13IezvOo5GvtGVsRap2s5RVqtIl1y0ke5/view?usp=share_link) | 88.64 | 96.18 | 97.45 | 98.17 | 98.66 | 99.23 |
| Glint360K | [TransFace-L](https://drive.google.com/file/d/1jXL_tidh9KqAS6MgeinIk2UNWmEaxfb0/view?usp=share_link) | 89.71 | 96.29 | 97.61 | 98.26 | 98.64 | 99.19 |

You can test the accuracy of these models (e.g., Glint360K TransFace-L):
```
python eval_ijbc.py --model-prefix work_dirs/glint360k_vit_l/glint360k_model_TransFace_L.pt --result-dir work_dirs/glint360k_vit_l --network vit_l_dp005_mask_005 > ijbc_glint360k_vit_l.log 2>&1 &
```
## Citation
* If you find TransFace helpful, please cite our paper:
```
@inproceedings{dan2023transface,
  title={TransFace: Calibrating Transformer Training for Face Recognition from a Data-Centric Perspective},
  author={Dan, Jun and Liu, Yang and Xie, Haoyu and Deng, Jiankang and Xie, Haoran and Xie, Xuansong and Sun, Baigui},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20642--20653},
  year={2023}
}
```
## Acknowledgments
We thank InsightFace for the excellent [code base](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch).
@@ -0,0 +1,90 @@
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)

    # mobilefacenet
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)

    elif name == "mbf_large":
        from .mobilefacenet import get_mbf_large
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf_large(fp16=fp16, num_features=num_features)

    # vision transformers: 112x112 input split into 9x9 patches
    elif name == "vit_t":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

    elif name == "vit_t_dp005_mask0":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

    elif name == "vit_s":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

    elif name == "vit_s_dp005_mask_0":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

    elif name == "vit_b":
        # deeper model: gradient checkpointing keeps memory in check
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)

    elif name == "vit_b_dp005_mask_005":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    elif name == "vit_l_dp005_mask_005":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    else:
        raise ValueError(f"Unknown model name: {name}")
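A hedged smoke test for `get_model`; it assumes the file above is the `backbones` package `__init__.py`, as in the InsightFace-style layout the README credits, and uses random tensors in place of real face crops:

```python
import torch

from backbones import get_model  # assumed package name

# Build an IResNet-50 backbone that emits 512-d embeddings.
net = get_model("r50", fp16=False, num_features=512)
net.eval()

# Two aligned 112x112 RGB face crops (random tensors stand in for real data).
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))
print(emb.shape)  # expected: torch.Size([2, 512])
```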
@@ -0,0 +1,194 @@
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
using_ckpt = False


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # BN-Conv-BN-PReLU-Conv-BN residual unit (the "improved" ResNet block).
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    def forward(self, x):
        # Optionally trade compute for memory with gradient checkpointing.
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7  # spatial size of the final feature map for 112x112 input

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # Freeze the scale of the final feature BN, as in ArcFace-style training.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Run the convolutional trunk in fp16 when enabled; the final fc
        # and feature BN always run in fp32 for numerical stability.
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError("No pretrained weights are available for IResNet")
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)
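A small feature-extraction sketch using this file directly; the L2 normalisation and random input are illustrative, and real use feeds aligned 112x112 face crops:

```python
import torch
import torch.nn.functional as F

# Build an IResNet-18 and extract an L2-normalised 512-d face embedding.
net = iresnet18(num_features=512)
net.eval()

x = torch.randn(1, 3, 112, 112)  # stand-in for one aligned face crop
with torch.no_grad():
    emb = F.normalize(net(x))    # unit-length embedding for cosine comparison
print(emb.shape)                 # torch.Size([1, 512])
```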