Showing 31 changed files with 2,995 additions and 1 deletion.
@@ -1 +1,147 @@
# vec2face
<div align="center">

# VEC2FACE: SCALING FACE DATASET GENERATION

[Haiyu Wu](https://haiyuwu.netlify.app/)<sup>1</sup> &emsp; [Jaskirat Singh](https://1jsingh.github.io/)<sup>2</sup> &emsp; [Sicong Tian](https://github.com/sicongT)<sup>3</sup>

[Liang Zheng](https://zheng-lab.cecs.anu.edu.au/)<sup>2</sup> &emsp; [Kevin W. Bowyer](https://www3.nd.edu/~kwb/)<sup>1</sup> &emsp;

<sup>1</sup>University of Notre Dame<br>
<sup>2</sup>The Australian National University<br>
<sup>3</sup>Indiana University South Bend

[//]: # (TODO)
[//]: # (<a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-green'></a>)
<a href='https://haiyuwu.github.io/vec2face.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
<a href='https://huggingface.co/BooBooWu/Vec2Face'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-orange'></a>

</div>

This is the official implementation of **[Vec2Face](https://haiyuwu.github.io/vec2face.github.io/)**, an ID- and attribute-controllable face dataset generation model:

&emsp;✅ that generates face images purely based on the given image features<br>
&emsp;✅ that achieves state-of-the-art performance on five standard test sets among synthetic datasets<br>
&emsp;✅ that is the first to achieve higher accuracy than a same-scale real dataset (on CALFW)<br>
&emsp;✅ that can easily scale the dataset size to 10M<br>

[//]: # (TODO)
<img src='assets/teaser_figure.png'>

# News/Updates
- [2024/09/01] 🔥 We release Vec2Face!

# :wrench: Installation
```bash
conda env create -f environment.yaml
```

# Download Model Weights
1) The weights of the Vec2Face model and estimators can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/vec2face_generator.pth", local_dir="./")
```
2) The weights of the FR models trained with HSFace (10k, 20k, 100k, 200k) can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface10k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface20k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface100k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface200k.pth", local_dir="./")
```
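
The per-file calls above can also be wrapped in a small loop. The following is only a sketch: `download_all`, `MODEL_FILES`, and `FR_FILES` are hypothetical names introduced here, and the optional `download` argument exists only so the helper can be tried without network access. The file names are the ones listed above.

```python
REPO_ID = "BooBooWu/Vec2Face"

# File names copied from the two download lists above.
MODEL_FILES = [
    "weights/6DRepNet_300W_LP_AFLW2000.pth",
    "weights/arcface-r100-glint360k.pth",
    "weights/magface-r100-glint360k.pth",
    "weights/vec2face_generator.pth",
]
FR_FILES = [f"fr_weights/hsface{size}.pth" for size in ("10k", "20k", "100k", "200k")]


def download_all(filenames, local_dir="./", download=None):
    """Fetch every file in `filenames` and return the list of local paths."""
    if download is None:
        # Imported lazily so the helper can be exercised with a stand-in downloader.
        from huggingface_hub import hf_hub_download as download
    return [
        download(repo_id=REPO_ID, filename=name, local_dir=local_dir)
        for name in filenames
    ]
```

Calling `download_all(MODEL_FILES + FR_FILES)` would then fetch all eight checkpoints into the current directory.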

# Download Datasets
1) The dataset used for **Vec2Face training** can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="lmdb_dataset/WebFace4M/WebFace4M.lmdb", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="lmdb_dataset/WebFace4M/50000_ids_1022444_ims.npy", local_dir="./")
```
2) The generated synthetic datasets (HSFace10k and HSFace20k for now) can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="hsfaces/hsface10k.lmdb", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="hsfaces/hsface20k.lmdb", local_dir="./")
```

# ⚡Image Generation
Image generation with sampled identity features:
```bash
python image_generation.py \
--model_weights weights/vec2face_generator.pth \
--batch_size 5 \
--example 1 \
--start_end 0:10 \
--name test \
--center_feature center_features.npy
```
Image generation with target yaw angle:
```bash
python pose_image_generation.py \
--model_weights weights/vec2face_generator.pth \
--batch_size 5 \
--example 1 \
--start_end 0:10 \
--center_feature center_features.npy \
--name test \
--pose 45 \
--image_quality 25
```

# Training
## Vec2Face training
We only provide the WebFace4M dataset (see [here]()) and the mask that we used for training the model. If you want to use other datasets, please refer to
[prepare_training_set.py]() to convert the dataset to .lmdb.
Once the dataset is ready, modify the following command to run the training:
```bash
torchrun --nproc_per_node=1 --node_rank=0 --master_addr="host_addr" --master_port=3333 vec2face.py \
--rep_drop_prob 0.1 \
--use_rep \
--batch_size 8 \
--model vec2face_vit_base_patch16 \
--epochs 2000 \
--warmup_epochs 5 \
--blr 4e-5 \
--output_dir workspace/pixel_generator/24_try \
--train_source ./lmdb_dataset/WebFace4M/WebFace4M.lmdb \
--mask lmdb_dataset/WebFace4M/50000_ids_1022444_ims.npy \
--accum_iter 1
```
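
When changing `--batch_size`, `--accum_iter`, or `--nproc_per_node` in the command above, it can help to work out the effective batch size first. The sketch below assumes that `--batch_size` is the per-process batch size. The learning-rate rule (`blr * effective_batch / 256`) is the linear scaling convention common in MAE-style training scripts; whether `vec2face.py` applies it is an assumption, not something this README confirms. Both function names are hypothetical.

```python
def effective_batch_size(batch_size, accum_iter, world_size):
    """Samples contributing to one optimizer step (assumes a per-process batch size)."""
    return batch_size * accum_iter * world_size


def scaled_lr(blr, eff_batch_size, base=256):
    """Linear scaling rule from MAE-style scripts; assumed, not confirmed for vec2face.py."""
    return blr * eff_batch_size / base


# Values taken from the example command above: one process, batch 8, no accumulation.
eff = effective_batch_size(batch_size=8, accum_iter=1, world_size=1)
print(eff)                   # 8
print(scaled_lr(4e-5, eff))  # about 1.25e-06 under the assumed rule
```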

## FR model training
We borrowed the code from [SOTA-Face-Recognition-Train-and-Test](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test) to train the model. The random erasing function can be added after line 84 in [data_loader_train_lmdb.py](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test/blob/main/data/data_loader_train_lmdb.py), as shown below:
```python
transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        transforms.RandomErasing()
    ]
)
```
Please follow the guidance of [SOTA-Face-Recognition-Train-and-Test](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test) for the rest of the training process.

## TODO
- [ ] HuggingFace demo
- [ ] 100k and 200k datasets

# Acknowledgements
- Thanks to the WebFace4M creators for providing such a high-quality facial dataset❤️.
- Thanks to [Hugging Face](https://huggingface.co/) for providing a handy dataset and model weight management platform❤️.

# Citation
If you find Vec2Face useful for your research, please consider citing us and starring😄:

```bibtex
@misc{wu2024vec2face,
  title={Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors},
  author={Wu, Haiyu and Singh, Jaskirat and Tian, Sicong and Zheng, Liang and Bowyer, Kevin W.},
  year={2024}
}
```
@@ -0,0 +1,41 @@
import numpy as np
from collections import defaultdict
import argparse


def identity_collector(args):
    im_paths = np.sort(np.genfromtxt(args.image_path, str))
    id_dict = defaultdict(list)
    id_list = []
    for im_path in im_paths:
        im_id = im_path.split("/")[-2]
        id_dict[im_id].append(im_path)

    for i, (_, v) in enumerate(id_dict.items()):
        for _ in v:
            id_list.append(i)

    np.savetxt(f"{args.destination}/{args.name}_labels.txt", id_list, fmt="%d")

    # im_paths = np.genfromtxt(args.image_path, str)
    # id_dict = defaultdict(int)
    # for im_path in im_paths:
    #     im_id = im_path.split("/")[-2]
    #     id_dict[im_id] += 1
    # print(np.mean(list(id_dict.values())))
    # print(np.std(list(id_dict.values())))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Collect identity labels."
    )
    parser.add_argument(
        "--image_path", "-im", help="A file that contains image paths.", type=str
    )
    parser.add_argument("--destination", "-d", help="destination.", type=str)
    parser.add_argument("--name", "-n", help="file name.", type=str)

    args = parser.parse_args()

    identity_collector(args)
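
To make the labelling rule in the script above concrete, here is a minimal standard-library sketch of the same mapping on a toy path list. `paths_to_labels` and the example paths are hypothetical, and plain `sorted` stands in for `np.sort`.

```python
from collections import defaultdict


def paths_to_labels(image_paths):
    """Assign one integer label per identity folder, in sorted path order."""
    id_dict = defaultdict(list)
    for im_path in sorted(image_paths):
        # The identity is the parent directory name, as in the script above.
        id_dict[im_path.split("/")[-2]].append(im_path)

    labels = []
    for label, paths in enumerate(id_dict.values()):
        labels.extend([label] * len(paths))
    return labels


example = ["data/id_b/0.jpg", "data/id_a/1.jpg", "data/id_a/0.jpg"]
print(paths_to_labels(example))  # [0, 0, 1]
```

The labels line up with the sorted path list only when each identity's images are contiguous after sorting, which holds when all identity folders share one parent directory.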
@@ -0,0 +1,127 @@
import os
from os import path, makedirs
import lmdb
import msgpack
import numpy as np
import pandas as pd
from PIL import Image
from os import path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm


class ImageListRaw(ImageFolder):
    def __init__(self, feature_list, label_file, image_list):
        image_names = np.asarray(pd.read_csv(image_list, delimiter=" ", header=None))
        feature_names = np.asarray(pd.read_csv(feature_list, delimiter=" ", header=None))
        self.im_samples = np.sort(image_names[:, 0])
        self.feat_samples = np.sort(feature_names[:, 0])

        self.targets = np.loadtxt(label_file, int)
        self.classnum = np.max(self.targets) + 1

        print(self.classnum)

    def __len__(self):
        return len(self.im_samples)

    def __getitem__(self, index):
        assert path.split(self.im_samples[index])[1][:-4] == path.split(self.feat_samples[index])[1][:-4]

        with open(self.im_samples[index], "rb") as f:
            img = f.read()
        with open(self.feat_samples[index], "rb") as f:
            feature = f.read()
        return img, feature, self.targets[index]


class CustomRawLoader(DataLoader):
    def __init__(self, workers, feature_list, label_file, image_list):
        self._dataset = ImageListRaw(feature_list, label_file, image_list)

        super(CustomRawLoader, self).__init__(
            self._dataset, num_workers=workers, collate_fn=lambda x: x
        )


def list2lmdb(
    feature_list,
    label_file,
    image_list,
    dest,
    file_name,
    num_workers=16,
    write_frequency=50000,
):
    print("Loading dataset from %s" % image_list)
    data_loader = CustomRawLoader(
        num_workers, feature_list, label_file, image_list
    )
    name = f"{file_name}.lmdb"
    if not path.exists(dest):
        makedirs(dest)
    lmdb_path = path.join(dest, name)
    isdir = path.isdir(lmdb_path)

    print(f"Generate LMDB to {lmdb_path}")

    image_size = 112
    size = len(data_loader.dataset) * image_size * image_size * 3
    print(f"LMDB max size: {size}")

    db = lmdb.open(
        lmdb_path,
        subdir=isdir,
        map_size=size * 2,
        readonly=False,
        meminit=False,
        map_async=True,
    )

    print(len(data_loader.dataset))
    txn = db.begin(write=True)
    for idx, data in tqdm(enumerate(data_loader)):
        image, feature, label = data[0]
        txn.put(
            "{}".format(idx).encode("ascii"), msgpack.dumps((image, feature, int(label)))
        )
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            txn.commit()
            txn = db.begin(write=True)
    idx += 1

    # finish iterating through dataset
    txn.commit()
    keys = ["{}".format(k).encode("ascii") for k in range(idx)]
    with db.begin(write=True) as txn:
        txn.put(b"__keys__", msgpack.dumps(keys))
        txn.put(b"__len__", msgpack.dumps(len(keys)))
        txn.put(b"__classnum__", msgpack.dumps(int(data_loader.dataset.classnum)))

    print("Flushing database ...")
    db.sync()
    db.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_list", "-im", help="List of images.")
    parser.add_argument("--feature_list", "-f", help="List of features.")
    parser.add_argument("--label_file", "-l", help="Identity label file.")
    parser.add_argument("--workers", "-w", help="Workers number.", default=8, type=int)
    parser.add_argument("--dest", "-d", help="Path to save the lmdb file.")
    parser.add_argument("--file_name", "-n", help="lmdb file name.")
    args = parser.parse_args()

    list2lmdb(
        args.feature_list,
        args.label_file,
        args.image_list,
        args.dest,
        args.file_name,
        args.workers,
    )
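
The `map_size` passed to `lmdb.open` above is twice the raw pixel footprint of the dataset (`N * 112 * 112 * 3` bytes). The sketch below only reproduces that arithmetic so the reservation can be estimated before conversion. `lmdb_map_size` and its `headroom` parameter are hypothetical names, and the 1,000-image count is an illustrative value, not a dataset size from this repository.

```python
def lmdb_map_size(num_images, image_size=112, channels=3, headroom=2):
    """Bytes reserved for the LMDB map, mirroring `size * 2` in the script above."""
    raw_bytes = num_images * image_size * image_size * channels
    return raw_bytes * headroom


# Illustrative count only: 1,000 images of 112x112x3 with 2x headroom.
print(lmdb_map_size(1000))  # 75264000
```

Because the script stores encoded image files plus feature bytes instead of raw pixels, this figure is a budget rather than the final file size. Whether the 2x headroom is sufficient depends on the size of the feature files, which is not stated here.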
@@ -0,0 +1,20 @@
import numpy as np
from collections import defaultdict


if __name__ == '__main__':
    np.random.seed(0)
    percent = 0.1
    dataset = "WebFace4M"
    image_paths = np.sort(np.genfromtxt(f"./{dataset}.txt", str))
    info_with_id = defaultdict(list)
    for i, im_path in enumerate(image_paths):
        im_id = im_path.split("/")[-2]
        info_with_id[im_id].append(i)
    selected_ids = np.random.choice(list(info_with_id.keys()), 50000, replace=False)

    selected_im_pos = []
    for selected_id in selected_ids:
        selected_im_pos += info_with_id[selected_id]
    np.save(f"./small_portion_masks/{dataset}/{50000}_ids_{len(selected_im_pos)}_ims.npy",
            selected_im_pos)
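
The script above builds the training mask by sampling identities and collecting the positions of all their images in the sorted path list. A minimal standard-library sketch of that idea follows. `select_identity_mask` is a hypothetical name, and `random.Random.sample` is swapped in for `np.random.choice(..., replace=False)`, so the identities it picks will not match the script's even with the same seed.

```python
import random
from collections import defaultdict


def select_identity_mask(image_paths, num_ids, seed=0):
    """Return positions (in sorted path order) of all images of `num_ids` random identities."""
    positions = defaultdict(list)
    for index, im_path in enumerate(sorted(image_paths)):
        # The identity is the parent directory name, as in the script above.
        positions[im_path.split("/")[-2]].append(index)

    rng = random.Random(seed)
    chosen = rng.sample(list(positions), num_ids)  # sampling without replacement

    mask = []
    for identity in chosen:
        mask.extend(positions[identity])
    return mask
```

The mask is selected per identity, not per image, so every image of a chosen identity is kept and the image count depends on which identities are drawn.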
@@ -0,0 +1,16 @@
model:
  target: pixel_generator.vec2face.taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 112
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0
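
The comment `num_down = len(ch_mult)-1` in the config above determines the latent grid size. As a worked example, and assuming each downsampling step halves the spatial size (as in the taming-transformers VQGAN encoder this config targets), the per-level resolutions can be computed as follows. `encoder_resolutions` is a hypothetical helper, not part of the repository.

```python
def encoder_resolutions(resolution, ch_mult):
    """Spatial size at each encoder level, assuming every downsample halves it."""
    num_down = len(ch_mult) - 1
    return [resolution // (2 ** level) for level in range(num_down + 1)]


res = encoder_resolutions(112, [1, 1, 2, 2, 4])
print(res)           # [112, 56, 28, 14, 7]
print(res[-1] ** 2)  # 49 latent positions, each with z_channels = 256
print(16 in res)     # False
```

Under this assumption, `attn_resolutions: [16]` matches none of the level resolutions for a 112-pixel input. If per-level attention is enabled by exact resolution match, as in the upstream encoder, none would be added at the intermediate levels.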
@@ -0,0 +1 @@
from .training_loader import LMDBDataLoader