vec2face
HaiyuWu committed Sep 1, 2024
1 parent 926f0f7 commit e0b439c
Showing 31 changed files with 2,995 additions and 1 deletion.
148 changes: 147 additions & 1 deletion README.md
@@ -1 +1,147 @@
# vec2face
<div align="center">

# VEC2FACE: SCALING FACE DATASET GENERATION

[Haiyu Wu](https://haiyuwu.netlify.app/)<sup>1</sup> &emsp; [Jaskirat Singh](https://1jsingh.github.io/)<sup>2</sup> &emsp; [Sicong Tian](https://github.com/sicongT)<sup>3</sup>

[Liang Zheng](https://zheng-lab.cecs.anu.edu.au/)<sup>2</sup> &emsp; [Kevin W. Bowyer](https://www3.nd.edu/~kwb/)<sup>1</sup> &emsp;

<sup>1</sup>University of Notre Dame<br>
<sup>2</sup>The Australian National University<br>
<sup>3</sup>Indiana University South Bend

[//]: # (TODO)
[//]: # (<a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-green'></a>)
<a href='https://haiyuwu.github.io/vec2face.github.io/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
<a href='https://huggingface.co/BooBooWu/Vec2Face'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-orange'></a>

</div>

This is the official implementation of **[Vec2Face](https://haiyuwu.github.io/vec2face.github.io/)**, an ID and attribute controllable face dataset generation model:

&emsp;✅ that generates face images purely based on the given image features<br>
&emsp;✅ that achieves state-of-the-art performance on five standard test sets among synthetic datasets<br>
&emsp;✅ that is the first to achieve higher accuracy than a same-scale real dataset (on CALFW)<br>
&emsp;✅ that can easily scale the dataset size to 10M<br>

[//]: # (TODO)
<img src='assets/teaser_figure.png'>

# News/Updates
- [2024/09/01] 🔥 We release Vec2Face!

# :wrench: Installation
```bash
conda env create -f environment.yaml
```

# Download Model Weights
1) The weights of the Vec2Face model and estimators can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/6DRepNet_300W_LP_AFLW2000.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/arcface-r100-glint360k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/magface-r100-glint360k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="weights/vec2face_generator.pth", local_dir="./")
```
2) The weights of the FR models trained with HSFace (10k, 20k, 100k, 200k) can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface10k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface20k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface100k.pth", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="fr_weights/hsface200k.pth", local_dir="./")
```
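
The repeated calls above can also be wrapped in a small loop. The sketch below is one possible way to do that and is not part of the repository: `download_files`, its `downloader` parameter, and the two list names are hypothetical, and the file names are copied from the snippets above.

```python
# Hypothetical convenience helper (not part of the Vec2Face repository).
MODEL_WEIGHTS = [
    "weights/6DRepNet_300W_LP_AFLW2000.pth",
    "weights/arcface-r100-glint360k.pth",
    "weights/magface-r100-glint360k.pth",
    "weights/vec2face_generator.pth",
]
FR_WEIGHTS = [f"fr_weights/hsface{size}.pth" for size in ("10k", "20k", "100k", "200k")]


def download_files(filenames, repo_id="BooBooWu/Vec2Face", local_dir="./", downloader=None):
    """Download each file in `filenames` and return the list of local paths."""
    if downloader is None:
        # Imported lazily so the helper can be tried out without network access.
        from huggingface_hub import hf_hub_download as downloader
    return [
        downloader(repo_id=repo_id, filename=name, local_dir=local_dir)
        for name in filenames
    ]
```

Calling `download_files(MODEL_WEIGHTS + FR_WEIGHTS)` would then fetch all eight files into the current directory.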

# Download Datasets
1) The dataset used for **Vec2Face training** can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="lmdb_dataset/WebFace4M/WebFace4M.lmdb", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="lmdb_dataset/WebFace4M/50000_ids_1022444_ims.npy", local_dir="./")
```
2) The generated synthetic datasets (HSFace10k and HSFace20k for now) can be downloaded manually from [HuggingFace](https://huggingface.co/BooBooWu/Vec2Face) or using Python:
```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="hsfaces/hsface10k.lmdb", local_dir="./")
hf_hub_download(repo_id="BooBooWu/Vec2Face", filename="hsfaces/hsface20k.lmdb", local_dir="./")
```

# ⚡Image Generation
Image generation with sampled identity features:
```bash
python image_generation.py \
--model_weights weights/vec2face_generator.pth \
--batch_size 5 \
--example 1 \
--start_end 0:10 \
--name test \
--center_feature center_features.npy
```
Image generation with target yaw angle:
```bash
python pose_image_generation.py \
--model_weights weights/vec2face_generator.pth \
--batch_size 5 \
--example 1 \
--start_end 0:10 \
--center_feature center_features.npy \
--name test \
--pose 45 \
--image_quality 25
```

# Training
## Vec2Face training
We only provide the WebFace4M dataset (see [here]()) and the mask that we used for training the model. If you want to use other datasets, please refer to
[prepare_training_set.py]() to convert the dataset to .lmdb.
Once the dataset is ready, modify the following command to run the training:
```bash
torchrun --nproc_per_node=1 --node_rank=0 --master_addr="host_addr" --master_port=3333 vec2face.py \
--rep_drop_prob 0.1 \
--use_rep \
--batch_size 8 \
--model vec2face_vit_base_patch16 \
--epochs 2000 \
--warmup_epochs 5 \
--blr 4e-5 \
--output_dir workspace/pixel_generator/24_try \
--train_source ./lmdb_dataset/WebFace4M/WebFace4M.lmdb \
--mask lmdb_dataset/WebFace4M/50000_ids_1022444_ims.npy \
--accum_iter 1
```

## FR model training
We borrowed the code from [SOTA-Face-Recognition-Train-and-Test](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test) to train the model. The random erasing function can be added after line 84 in [data_loader_train_lmdb.py](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test/blob/main/data/data_loader_train_lmdb.py), as shown below:
```python
transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        transforms.RandomErasing()
    ]
)
```
Please follow the guidance of [SOTA-Face-Recognition-Train-and-Test](https://github.com/HaiyuWu/SOTA-Face-Recognition-Train-and-Test) for the rest of the training process.

## TODO
- [ ] HuggingFace demo
- [ ] 100k and 200k datasets

# Acknowledgements
- Thanks to the WebFace4M creators for providing such a high-quality facial dataset❤️.
- Thanks to [Hugging Face](https://huggingface.co/) for providing a handy dataset and model weight management platform❤️.

# Citation
If you find Vec2Face useful for your research, please consider citing us and starring😄:

```bibtex
@misc{wu2024vec2face,
title={Vec2Face: Scaling Face Dataset Generation with Loosely Constrained Vectors},
author={Wu, Haiyu and Singh, Jaskirat and Tian, Sicong and Zheng, Liang and Bowyer, Kevin W.},
year={2024}
}
```
41 changes: 41 additions & 0 deletions Scripts/identity_label_collection.py
@@ -0,0 +1,41 @@
import numpy as np
from collections import defaultdict
import argparse


def identity_collector(args):
    # Read the list of image paths and sort it.
    im_paths = np.sort(np.genfromtxt(args.image_path, str))
    id_dict = defaultdict(list)
    id_list = []
    for im_path in im_paths:
        # The identity is the name of the image's parent directory.
        im_id = im_path.split("/")[-2]
        id_dict[im_id].append(im_path)

    # Write one integer label per image; identities are numbered in first-seen order.
    for i, (_, v) in enumerate(id_dict.items()):
        for _ in v:
            id_list.append(i)

    np.savetxt(f"{args.destination}/{args.name}_labels.txt", id_list, fmt="%d")

    # im_paths = np.genfromtxt(args.image_path, str)
    # id_dict = defaultdict(int)
    # for im_path in im_paths:
    #     im_id = im_path.split("/")[-2]
    #     id_dict[im_id] += 1
    # print(np.mean(list(id_dict.values())))
    # print(np.std(list(id_dict.values())))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Collect identity labels."
    )
    parser.add_argument(
        "--image_path", "-im", help="A file that contains image paths.", type=str
    )
    parser.add_argument("--destination", "-d", help="destination.", type=str)
    parser.add_argument("--name", "-n", help="file name.", type=str)

    args = parser.parse_args()

    identity_collector(args)
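
To make the labelling rule concrete, here is a small self-contained sketch of the same idea on made-up paths. It uses only the standard library instead of NumPy; `labels_from_paths` and the example paths are hypothetical and not part of this commit.

```python
from collections import defaultdict


def labels_from_paths(paths):
    """Return one integer label per image, grouped by parent directory name."""
    id_dict = defaultdict(list)
    for p in sorted(paths):
        # The identity is the parent directory, as in identity_collector.
        id_dict[p.split("/")[-2]].append(p)
    labels = []
    for label, images in enumerate(id_dict.values()):
        labels.extend([label] * len(images))
    return labels


paths = ["data/bob/1.jpg", "data/alice/2.jpg", "data/alice/1.jpg", "data/carol/1.jpg"]
print(labels_from_paths(paths))  # [0, 0, 1, 2]
```

After sorting, the two `alice` images come first and share label 0, followed by `bob` (1) and `carol` (2).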
127 changes: 127 additions & 0 deletions Scripts/prepare_training_set.py
@@ -0,0 +1,127 @@
import os
from os import path, makedirs
import lmdb
import msgpack
import numpy as np
import pandas as pd
from PIL import Image
from os import path
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm


class ImageListRaw(ImageFolder):
    def __init__(self, feature_list, label_file, image_list):
        image_names = np.asarray(pd.read_csv(image_list, delimiter=" ", header=None))
        feature_names = np.asarray(pd.read_csv(feature_list, delimiter=" ", header=None))
        self.im_samples = np.sort(image_names[:, 0])
        self.feat_samples = np.sort(feature_names[:, 0])

        self.targets = np.loadtxt(label_file, int)
        self.classnum = np.max(self.targets) + 1

        print(self.classnum)

    def __len__(self):
        return len(self.im_samples)

    def __getitem__(self, index):
        assert path.split(self.im_samples[index])[1][:-4] == path.split(self.feat_samples[index])[1][:-4]

        with open(self.im_samples[index], "rb") as f:
            img = f.read()
        with open(self.feat_samples[index], "rb") as f:
            feature = f.read()
        return img, feature, self.targets[index]


class CustomRawLoader(DataLoader):
    def __init__(self, workers, feature_list, label_file, image_list):
        self._dataset = ImageListRaw(feature_list, label_file, image_list)

        super(CustomRawLoader, self).__init__(
            self._dataset, num_workers=workers, collate_fn=lambda x: x
        )


def list2lmdb(
    feature_list,
    label_file,
    image_list,
    dest,
    file_name,
    num_workers=16,
    write_frequency=50000,
):
    print("Loading dataset from %s" % image_list)
    data_loader = CustomRawLoader(
        num_workers, feature_list, label_file, image_list
    )
    name = f"{file_name}.lmdb"
    if not path.exists(dest):
        makedirs(dest)
    lmdb_path = path.join(dest, name)
    isdir = path.isdir(lmdb_path)

    print(f"Generate LMDB to {lmdb_path}")

    image_size = 112
    size = len(data_loader.dataset) * image_size * image_size * 3
    print(f"LMDB max size: {size}")

    db = lmdb.open(
        lmdb_path,
        subdir=isdir,
        map_size=size * 2,
        readonly=False,
        meminit=False,
        map_async=True,
    )

    print(len(data_loader.dataset))
    txn = db.begin(write=True)
    for idx, data in tqdm(enumerate(data_loader)):
        image, feature, label = data[0]
        txn.put(
            "{}".format(idx).encode("ascii"), msgpack.dumps((image, feature, int(label)))
        )
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            txn.commit()
            txn = db.begin(write=True)
    idx += 1

    # finish iterating through dataset
    txn.commit()
    keys = ["{}".format(k).encode("ascii") for k in range(idx)]
    with db.begin(write=True) as txn:
        txn.put(b"__keys__", msgpack.dumps(keys))
        txn.put(b"__len__", msgpack.dumps(len(keys)))
        txn.put(b"__classnum__", msgpack.dumps(int(data_loader.dataset.classnum)))

    print("Flushing database ...")
    db.sync()
    db.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_list", "-im", help="List of images.")
    parser.add_argument("--feature_list", "-f", help="List of features.")
    parser.add_argument("--label_file", "-l", help="Identity label file.")
    parser.add_argument("--workers", "-w", help="Workers number.", default=8, type=int)
    parser.add_argument("--dest", "-d", help="Path to save the lmdb file.")
    parser.add_argument("--file_name", "-n", help="lmdb file name.")
    args = parser.parse_args()

    list2lmdb(
        args.feature_list,
        args.label_file,
        args.image_list,
        args.dest,
        args.file_name,
        args.workers,
    )
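
The `map_size` passed to `lmdb.open` in `list2lmdb` is derived from the image count: the byte size of one uncompressed 112x112 RGB image per record, then doubled. As a rough worked example of that arithmetic (the function name `lmdb_map_size` and the image count of 1,000 are illustrative assumptions, not part of the script):

```python
def lmdb_map_size(num_images, image_size=112, channels=3, factor=2):
    """Mirror the script's estimate: raw bytes per image, times a fixed factor."""
    size = num_images * image_size * image_size * channels
    return size * factor


print(lmdb_map_size(1000))  # 75264000 bytes, roughly 72 MiB
```

Since the stored values are the raw image file bytes plus the feature file bytes, this is an estimate of the space to reserve rather than the exact size of the resulting database.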
20 changes: 20 additions & 0 deletions Scripts/small_portion_training_mask.py
@@ -0,0 +1,20 @@
import numpy as np
from collections import defaultdict


if __name__ == '__main__':
    np.random.seed(0)
    percent = 0.1
    dataset = "WebFace4M"
    image_paths = np.sort(np.genfromtxt(f"./{dataset}.txt", str))
    info_with_id = defaultdict(list)
    for i, im_path in enumerate(image_paths):
        im_id = im_path.split("/")[-2]
        info_with_id[im_id].append(i)
    selected_ids = np.random.choice(list(info_with_id.keys()), 50000, replace=False)

    selected_im_pos = []
    for selected_id in selected_ids:
        selected_im_pos += info_with_id[selected_id]
    np.save(f"./small_portion_masks/{dataset}/{50000}_ids_{len(selected_im_pos)}_ims.npy",
            selected_im_pos)
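
The saved mask is the list of image positions (indices into the sorted path list) that belong to a random subset of identities. A toy sketch of that selection follows; it uses the standard library's `random.sample` as a stand-in for `np.random.choice(..., replace=False)`, and `build_mask` and the example paths are hypothetical.

```python
import random
from collections import defaultdict


def build_mask(image_paths, num_ids, seed=0):
    """Return sorted-list positions of all images from `num_ids` random identities."""
    positions_by_id = defaultdict(list)
    for i, p in enumerate(sorted(image_paths)):
        # Group image positions by identity (the parent directory name).
        positions_by_id[p.split("/")[-2]].append(i)
    rng = random.Random(seed)
    selected_ids = rng.sample(list(positions_by_id), num_ids)
    mask = []
    for identity in selected_ids:
        mask += positions_by_id[identity]
    return mask


paths = [f"data/id{k}/{j}.jpg" for k in range(5) for j in range(3)]
mask = build_mask(paths, num_ids=2)
print(len(mask))  # 6: two identities with three images each
```

Every image of a selected identity ends up in the mask, so the number of selected images depends on which identities are drawn; this is why the script encodes `len(selected_im_pos)` in the output file name.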
Binary file added asset/teaser_figure.png
16 changes: 16 additions & 0 deletions configs/vec2face/vqgan.yaml
@@ -0,0 +1,16 @@
model:
  target: pixel_generator.vec2face.taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 112
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0
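
The inline comment `num_down = len(ch_mult)-1` determines the size of the latent grid. The short sketch below checks that arithmetic for the values in this config. It assumes that each downsampling stage halves the spatial resolution and that the width of each level is `ch * ch_mult[i]`, which is how the taming-transformers VQGAN encoder is commonly described; neither assumption is confirmed by this commit.

```python
# Values copied from configs/vec2face/vqgan.yaml.
ch, ch_mult, resolution = 128, [1, 1, 2, 2, 4], 112

num_down = len(ch_mult) - 1                      # 4 downsampling stages
latent_resolution = resolution // 2 ** num_down  # 112 / 16 = 7
level_channels = [ch * m for m in ch_mult]       # [128, 128, 256, 256, 512]

print(num_down, latent_resolution, level_channels)
```

Under these assumptions a 112x112 input would be encoded into a 7x7 grid of latent vectors.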
1 change: 1 addition & 0 deletions dataloader/__init__.py
@@ -0,0 +1 @@
from .training_loader import LMDBDataLoader