Commit

TransFace Code (modelscope#525)
* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Delete face_module/TransFace/README.md

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload
DanJun6737 authored Mar 8, 2024
1 parent 3ad2494 commit 0ce8353
Showing 42 changed files with 5,520 additions and 1 deletion.
38 changes: 38 additions & 0 deletions face_module/TransFace/FFT.py
@@ -0,0 +1,38 @@
import cv2
import numpy as np
from math import sqrt


def amplitude_spectrum_mix(img1, img2, alpha, ratio=1.0):  # img_src, img_random
    """Mix the central amplitude spectra of two images.

    Inputs: ndarrays of shape [H, W, C].
    """
    # Random mixing weight, sampled per call.
    lam = np.random.uniform(0, alpha)
    assert img1.shape == img2.shape
    h, w, c = img1.shape
    # Size and position of the central (low-frequency) patch to mix.
    h_crop = int(h * sqrt(ratio))
    w_crop = int(w * sqrt(ratio))
    h_start = h // 2 - h_crop // 2
    w_start = w // 2 - w_crop // 2

    # 2D FFT per channel; split into amplitude and phase.
    img1_fft = np.fft.fft2(img1, axes=(0, 1))
    img2_fft = np.fft.fft2(img2, axes=(0, 1))
    img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft)
    img2_abs, img2_pha = np.abs(img2_fft), np.angle(img2_fft)

    # Shift the zero-frequency component to the spectrum center.
    img1_abs = np.fft.fftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.fftshift(img2_abs, axes=(0, 1))

    img1_abs_ = np.copy(img1_abs)
    img2_abs_ = np.copy(img2_abs)

    # Blend the central amplitude patch of img1 with that of img2.
    img1_abs[h_start:h_start + h_crop, w_start:w_start + w_crop] = \
        lam * img2_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop] + \
        (1 - lam) * img1_abs_[h_start:h_start + h_crop, w_start:w_start + w_crop]

    img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1))
    img2_abs = np.fft.ifftshift(img2_abs, axes=(0, 1))

    # Recombine the mixed amplitude with the original phase of img1.
    img_src_random = img1_abs * (np.e ** (1j * img1_pha))
    img_src_random = np.real(np.fft.ifft2(img_src_random, axes=(0, 1)))
    img_src_random = np.uint8(np.clip(img_src_random, 0, 255))
    return img_src_random
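
A minimal usage sketch for `amplitude_spectrum_mix` (the file paths and the `alpha` value below are placeholders for illustration, not values from this commit):
```
import cv2
from FFT import amplitude_spectrum_mix

# Hypothetical inputs: two face images resized to a common shape.
img_src = cv2.imread("face_a.jpg")      # placeholder path
img_aux = cv2.imread("face_b.jpg")      # placeholder path
img_aux = cv2.resize(img_aux, (img_src.shape[1], img_src.shape[0]))

# alpha bounds the random mixing weight; ratio=1.0 mixes the full spectrum.
mixed = amplitude_spectrum_mix(img_src, img_aux, alpha=0.5, ratio=1.0)
cv2.imwrite("mixed.jpg", mixed)
```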

73 changes: 72 additions & 1 deletion face_module/TransFace/README.md
@@ -1 +1,72 @@
### The code will be released in the future.
# TransFace: Calibrating Transformer Training for Face Recognition from a Data-Centric Perspective (ICCV-2023)
This is the official PyTorch implementation of [TransFace](https://openaccess.thecvf.com/content/ICCV2023/html/Dan_TransFace_Calibrating_Transformer_Training_for_Face_Recognition_from_a_Data-Centric_ICCV_2023_paper.html).

[Arxiv](https://arxiv.org/abs/2308.10133)

## ModelScope
You can quickly try out and invoke our TransFace model on [ModelScope](https://modelscope.cn/models/damo/cv_vit_face-recognition/summary).

## Requirements
* Install PyTorch (torch>=1.9.0)
* ```pip install -r requirement.txt```

## Datasets
You can download the training datasets, including MS1MV2 and Glint360K:
* MS1MV2: [Google Drive](https://drive.google.com/file/d/1SXS4-Am3bsKSK615qbYdbA_FMVh3sAvR/view)
* Glint360K: [Baidu](https://pan.baidu.com/share/init?surl=GsYqTTt7_Dn8BfxxsLFN0w) (extraction code: o3az)

You can download the test dataset IJB-C as follows:
* IJB-C: [Google Drive](https://drive.google.com/file/d/1aC4zf2Bn0xCVH_ZtEuQipR2JvRb1bf8o/view)

## How to Train Models
1. Modify the training-data path in every configuration file under the `configs` folder (see the config sketch after these steps).

2. To run on a machine with 8 GPUs:
```
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=12581 train.py
```
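
For reference, a sketch of what such a config edit might look like, assuming the arcface_torch-style configuration format that this code base builds on; the field names and values below are assumptions, not taken from this commit:
```
# configs/ms1mv2_vit_s.py -- illustrative sketch, not a file from this commit
from easydict import EasyDict as edict

config = edict()
config.network = "vit_s"        # backbone name resolved by get_model()
config.rec = "/data/ms1mv2"     # <-- point this at your local MS1MV2 copy
config.num_classes = 85742      # identity count for MS1MV2
config.num_image = 5822653      # image count for MS1MV2
config.num_epoch = 40
```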

## How to Test Models
1. Modify the path of the IJB-C dataset in `eval_ijbc.py`.

2. Run:
```
python eval_ijbc.py --model-prefix work_dirs/glint360k_vit_s/model.pt --result-dir work_dirs/glint360k_vit_s --network vit_s_dp005_mask_0 > ijbc_glint360k_vit_s.log 2>&1 &
```

## TransFace Pretrained Models

You can download the TransFace models reported in our paper as follows:

| Training Data | Model | IJB-C(1e-6) | IJB-C(1e-5) | IJB-C(1e-4) | IJB-C(1e-3) | IJB-C(1e-2) | IJB-C(1e-1) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| MS1MV2 | [TransFace-S](https://drive.google.com/file/d/1UZWCg7jNESDv8EWs7mxQSswCMGbAZNF4/view?usp=share_link) | 86.75 | 93.87 | 96.45 | 97.51 | 98.34 | 98.99 |
| MS1MV2 | [TransFace-B](https://drive.google.com/file/d/16O-q30mH8d3lECqa5eJd8rABaUlNhQ0K/view?usp=share_link) | 86.73 | 94.15 | 96.55 | 97.73 | 98.47 | 99.11 |
| MS1MV2 | [TransFace-L](https://drive.google.com/file/d/1uXUFT6ujEPqvCTHzONsp6-DMIc24Cc85/view?usp=share_link) | 86.90 | 94.55 | 96.59 | 97.80 | 98.45 | 99.04 |

| Training Data | Model | IJB-C(1e-6) | IJB-C(1e-5) | IJB-C(1e-4) | IJB-C(1e-3) | IJB-C(1e-2) | IJB-C(1e-1) |
| ------ | ------ | ------ | ------ | ------ | ------ | ------ | ------ |
| Glint360K | [TransFace-S](https://drive.google.com/file/d/18Zh_zMlYttKVIGArmDYNEchIvUSH5FQ1/view?usp=share_link) | 89.93 | 96.06 | 97.33 | 98.00 | 98.49 | 99.11 |
| Glint360K | [TransFace-B](https://drive.google.com/file/d/13IezvOo5GvtGVsRap2s5RVqtIl1y0ke5/view?usp=share_link) | 88.64 | 96.18 | 97.45 | 98.17 | 98.66 | 99.23 |
| Glint360K | [TransFace-L](https://drive.google.com/file/d/1jXL_tidh9KqAS6MgeinIk2UNWmEaxfb0/view?usp=share_link) | 89.71 | 96.29 | 97.61 | 98.26 | 98.64 | 99.19 |

You can test the accuracy of these models (e.g., Glint360K TransFace-L):
```
python eval_ijbc.py --model-prefix work_dirs/glint360k_vit_l/glint360k_model_TransFace_L.pt --result-dir work_dirs/glint360k_vit_l --network vit_l_dp005_mask_005 > ijbc_glint360k_vit_l.log 2>&1 &
```

## Citation
* If you find this work helpful, please cite our paper:
```
@inproceedings{dan2023transface,
  title={TransFace: Calibrating Transformer Training for Face Recognition from a Data-Centric Perspective},
  author={Dan, Jun and Liu, Yang and Xie, Haoyu and Deng, Jiankang and Xie, Haoran and Xie, Xuansong and Sun, Baigui},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={20642--20653},
  year={2023}
}
```

## Acknowledgments
We thank InsightFace for the excellent [code base](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch).

90 changes: 90 additions & 0 deletions face_module/TransFace/backbones/__init__.py
@@ -0,0 +1,90 @@
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)

    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)

    elif name == "mbf_large":
        from .mobilefacenet import get_mbf_large
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf_large(fp16=fp16, num_features=num_features)

    elif name == "vit_t":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

    elif name == "vit_t_dp005_mask0":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

    elif name == "vit_s":
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)

    elif name == "vit_s_dp005_mask_0":  # For WebFace42M
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)

    elif name == "vit_b":
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
            num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)

    elif name == "vit_b_dp005_mask_005":  # For WebFace42M
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    elif name == "vit_l_dp005_mask_005":  # For WebFace42M
        # this is a feature
        num_features = kwargs.get("num_features", 512)
        from .vit import VisionTransformer
        return VisionTransformer(
            img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
            num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)

    else:
        raise ValueError(f"Unknown backbone name: {name}")
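
A minimal sketch of building a backbone through this factory and checking the embedding shape (the 112x112 input size follows the constructors above; the example itself is illustrative):
```
import torch
from backbones import get_model

# Build the ViT-S backbone ("vit_s") with 512-d embeddings.
model = get_model("vit_s", num_features=512)
model.eval()

# One aligned 112x112 RGB face crop -> one 512-d feature vector.
x = torch.randn(1, 3, 112, 112)
with torch.no_grad():
    emb = model(x)
print(emb.shape)  # expected: torch.Size([1, 512])
```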





194 changes: 194 additions & 0 deletions face_module/TransFace/backbones/iresnet.py
@@ -0,0 +1,194 @@
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
using_ckpt = False


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError("No pretrained weights are bundled with this implementation")
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)
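
A quick smoke test for these residual backbones (illustrative only; assumes the package layout shown in this diff):
```
import torch
from backbones.iresnet import iresnet50

# IResNet-50 with the default 512-d embedding head; fp16=True would run
# the convolutional trunk under torch.cuda.amp.autocast as coded above.
model = iresnet50(num_features=512, fp16=False)
model.eval()

x = torch.randn(2, 3, 112, 112)  # two dummy 112x112 face crops
with torch.no_grad():
    emb = model(x)
print(emb.shape)  # expected: torch.Size([2, 512])
```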