
Commit

add requirements.txt
yurui committed Sep 20, 2021
1 parent 88ba391 commit 401fe6e
Showing 20 changed files with 2,071 additions and 6 deletions.
41 changes: 35 additions & 6 deletions README.md
@@ -15,7 +15,7 @@

The source code of the ICCV2021 paper "[PIRenderer: Controllable Portrait Image Generation via Semantic Neural Rendering](https://arxiv.org/abs/2109.08379)"

- The proposed **PIRenderer** can synthesis portrait images by intuitively controlling the face motions with fully disentangled 3DMM parameters. This model can be applied :
+ The proposed **PIRenderer** can synthesize portrait images by intuitively controlling the face motions with fully disentangled 3DMM parameters. This model can be applied to tasks such as:

* **Intuitive Portrait Image Editing**

@@ -77,8 +77,9 @@ Coming soon
# 1. Create a conda virtual environment.
conda create -n PIRenderer python=3.6
conda activate PIRenderer
+ conda install -c pytorch pytorch=1.7.1 torchvision cudatoolkit=10.2

- # 2. Install dependency
+ # 2. Install other dependencies
pip install -r requirements.txt
```
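
After installing, a quick sanity check confirms the expected build is active. This is a minimal sketch; it assumes only the versions pinned in `requirements.txt`:

```python
# Sanity check for the environment above. Versions should match the pins
# in requirements.txt (torch==1.7.1, torchvision==0.8.2).
import torch
import torchvision

print(torch.__version__)          # expected: 1.7.1
print(torchvision.__version__)    # expected: 0.8.2
print(torch.cuda.is_available())  # True if the CUDA 10.2 build can see a GPU
```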

@@ -88,10 +89,10 @@ We train our model using the [VoxCeleb](https://arxiv.org/abs/1706.08612). You c

#### Download the demo dataset

- You can download the demo dataset with the following code:
+ The demo dataset contains all 514 test videos. You can download the dataset with the following command:

```bash
- ./download_dataset.sh
+ ./scripts/download_demo_dataset.sh
```

#### Prepare the dataset
@@ -100,21 +101,49 @@ You can download the demo dataset with the following code:

2. After obtaining the VoxCeleb videos, we extract 3DMM parameters using [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).

The folders are organized as follows (a consistency check between the two trees is sketched after this list):

```
${DATASET_ROOT_FOLDER}
└───path_to_videos
└───train
└───xxx.mp4
└───xxx.mp4
...
└───test
└───xxx.mp4
└───xxx.mp4
...
└───path_to_3dmm_coeff
└───train
└───xxx.mat
└───xxx.mat
...
└───test
└───xxx.mat
└───xxx.mat
...
```

3. We save the videos and 3DMM parameters in an lmdb file. Please run the following command to do this (a read-back sketch follows this list):

```bash
- python util.write_data_to_lmdb.py
+ python scripts/prepare_vox_lmdb.py \
+     --path path_to_videos \
+     --coeff_3dmm_path path_to_3dmm_coeff \
+     --out path_to_output_dir
```
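
Before building the lmdb, it is worth checking that the two trees are consistent. A minimal sketch, assuming the flat `path_to_videos`/`path_to_3dmm_coeff` layout shown in the list above (adjust the globs if your videos are nested deeper):

```python
import os
from glob import glob

# Hedged sketch: verify every train/test video has a matching .mat file.
# 'path_to_videos' and 'path_to_3dmm_coeff' are the placeholder roots above.
for split in ('train', 'test'):
    videos = {os.path.splitext(os.path.basename(p))[0]
              for p in glob(os.path.join('path_to_videos', split, '*.mp4'))}
    coeffs = {os.path.splitext(os.path.basename(p))[0]
              for p in glob(os.path.join('path_to_3dmm_coeff', split, '*.mat'))}
    print(split, 'videos missing coefficients:', sorted(videos - coeffs))
```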
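
Once the lmdb has been written, its keys follow `format_for_lmdb` in `scripts/prepare_vox_lmdb.py` (shown later in this commit). A minimal read-back sketch, assuming the default size `256` and a hypothetical video id `xxx` taken from the generated `train_list.txt`:

```python
import lmdb
import numpy as np
from io import BytesIO
from PIL import Image

# Read one frame and the 3DMM coefficients back. 'path_to_output_dir/256'
# and 'xxx' are placeholders; keys mirror format_for_lmdb in the script.
env = lmdb.open('path_to_output_dir/256', readonly=True, lock=False)
with env.begin() as txn:
    video_name = 'xxx'
    n_frames = int(txn.get(f'{video_name}-length'.encode('utf-8')))
    frame0 = Image.open(BytesIO(txn.get(f'{video_name}-{0:07d}'.encode('utf-8'))))
    # Shape/dtype assume float32 coefficients; adjust if your .mat files differ.
    coeff = np.frombuffer(txn.get(f'{video_name}-coeff_3dmm'.encode('utf-8')),
                          dtype=np.float32).reshape(n_frames, -1)
print(frame0.size, coeff.shape)
```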



### 3). Training and Inference

#### Inference

The trained weights can be downloaded by running the following command:

```bash
- ./download_weights.sh
+ ./scripts/download_weights.sh
```

Alternatively, you can download the resources from these links: coming soon. Then save the files to `./result/face`.
68 changes: 68 additions & 0 deletions requirements.txt
@@ -0,0 +1,68 @@
absl-py==0.13.0
backcall==0.2.0
cachetools==4.2.2
certifi==2021.5.30
charset-normalizer==2.0.6
cycler==0.10.0
dataclasses==0.8
decorator==4.4.2
filelock==3.0.12
gdown==3.13.1
google-auth==1.35.0
google-auth-oauthlib==0.4.6
grpcio==1.40.0
idna==3.2
imageio==2.9.0
importlib-metadata==4.8.1
ipython==7.16.1
ipython-genutils==0.2.0
jedi==0.18.0
kiwisolver==1.3.1
lmdb==1.2.1
Markdown==3.3.4
matplotlib==3.3.4
mkl-fft==1.3.0
mkl-random==1.1.1
mkl-service==2.3.0
networkx==2.5.1
numpy==1.19.2
oauthlib==3.1.1
olefile==0.46
opencv-python==4.5.3.56
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
pip==21.2.2
prompt-toolkit==3.0.20
protobuf==3.18.0
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.10.0
pyparsing==2.4.7
PySocks==1.7.1
python-dateutil==2.8.2
PyWavelets==1.1.1
PyYAML==5.4.1
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-image==0.17.2
scipy==1.5.4
setuptools==58.0.4
six==1.16.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tifffile==2020.9.3
torch==1.7.1
torchvision==0.8.2
tqdm==4.62.2
traitlets==4.3.3
typing-extensions==3.10.0.2
urllib3==1.26.6
wcwidth==0.2.5
Werkzeug==2.0.1
wheel==0.37.0
zipp==3.5.0
156 changes: 156 additions & 0 deletions scripts/prepare_vox_lmdb.py
@@ -0,0 +1,156 @@
import os
import cv2
import lmdb
import argparse
import multiprocessing
import numpy as np

from glob import glob
from io import BytesIO
from tqdm import tqdm
from PIL import Image
from scipy.io import loadmat
from torchvision.transforms import functional as trans_fn


def format_for_lmdb(*args):
    """Join key parts with '-', zero-padding integer parts to 7 digits."""
    key_parts = []
    for arg in args:
        if isinstance(arg, int):
            arg = str(arg).zfill(7)
        key_parts.append(arg)
    return '-'.join(key_parts).encode('utf-8')


class Resizer:
    def __init__(self, size, kp_root, coeff_3dmm_root, img_format):
        self.size = size
        self.kp_root = kp_root
        self.coeff_3dmm_root = coeff_3dmm_root
        self.img_format = img_format

    def get_resized_bytes(self, img, img_format='jpeg'):
        img = trans_fn.resize(img, (self.size, self.size), interpolation=Image.BICUBIC)
        buf = BytesIO()
        img.save(buf, format=img_format)
        img_bytes = buf.getvalue()
        return img_bytes

    def prepare(self, filename):
        # Decode every frame, re-encode it at the target size, and attach the
        # per-video keypoint and 3DMM byte blobs if their roots were given.
        frames = {'img': [], 'kp': None, 'coeff_3dmm': None}
        cap = cv2.VideoCapture(filename)
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                img_pil = Image.fromarray(frame)
                img_bytes = self.get_resized_bytes(img_pil, self.img_format)
                frames['img'].append(img_bytes)
            else:
                break
        cap.release()
        video_name = os.path.splitext(os.path.basename(filename))[0]
        keypoint_byte = get_others(self.kp_root, video_name, 'keypoint')
        coeff_3dmm_byte = get_others(self.coeff_3dmm_root, video_name, 'coeff_3dmm')
        frames['kp'] = keypoint_byte
        frames['coeff_3dmm'] = coeff_3dmm_byte
        return frames

    def __call__(self, index_filename):
        index, filename = index_filename
        result = self.prepare(filename)
        return index, result, filename


def get_others(root, video_name, data_type):
    if root is None:
        return None
    assert data_type in ('keypoint', 'coeff_3dmm')
    # The .mat file sits under either the train or the test subfolder.
    if os.path.isfile(os.path.join(root, 'train', video_name + '.mat')):
        file_path = os.path.join(root, 'train', video_name + '.mat')
    else:
        file_path = os.path.join(root, 'test', video_name + '.mat')

    if data_type == 'keypoint':
        return convert_kp(file_path)
    return convert_3dmm(file_path)


def convert_kp(file_path):
    file_mat = loadmat(file_path)
    kp_byte = file_mat['landmark'].tobytes()
    return kp_byte


def convert_3dmm(file_path):
    # Concatenate the 3DMM coefficients with the crop parameters
    # (scale ratio and the two translation terms) for each frame.
    file_mat = loadmat(file_path)
    coeff_3dmm = file_mat['coeff']
    crop_param = file_mat['transform_params']
    _, _, ratio, t0, t1 = np.hsplit(crop_param.astype(np.float32), 5)
    crop_param = np.concatenate([ratio, t0, t1], 1)
    coeff_3dmm_cat = np.concatenate([coeff_3dmm, crop_param], 1)
    coeff_3dmm_byte = coeff_3dmm_cat.tobytes()
    return coeff_3dmm_byte


def prepare_data(path, keypoint_path, coeff_3dmm_path, out, n_worker, sizes, chunksize, img_format):
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    for ext in VIDEO_EXTENSIONS:
        filenames += glob(f'{path}/**/*.{ext}', recursive=True)

    train_video, test_video = [], []
    for item in filenames:
        if "/train/" in item:
            train_video.append(item)
        else:
            test_video.append(item)
    print(f'{len(train_video)} train videos, {len(test_video)} test videos')

    os.makedirs(out, exist_ok=True)  # create the output dir before writing the lists
    with open(os.path.join(out, 'train_list.txt'), 'w') as f:
        for item in train_video:
            item = os.path.splitext(os.path.basename(item))[0]
            f.write(item + '\n')

    with open(os.path.join(out, 'test_list.txt'), 'w') as f:
        for item in test_video:
            item = os.path.splitext(os.path.basename(item))[0]
            f.write(item + '\n')

    filenames = sorted(filenames)
    total = len(filenames)
    for size in sizes:
        lmdb_path = os.path.join(out, str(size))
        with lmdb.open(lmdb_path, map_size=1024 ** 4, readahead=False) as env:
            with env.begin(write=True) as txn:
                txn.put(format_for_lmdb('length'), format_for_lmdb(total))
                resizer = Resizer(size, keypoint_path, coeff_3dmm_path, img_format)
                with multiprocessing.Pool(n_worker) as pool:
                    for idx, result, filename in tqdm(
                            pool.imap_unordered(resizer, enumerate(filenames), chunksize=chunksize),
                            total=total):
                        filename = os.path.basename(filename)
                        video_name = os.path.splitext(filename)[0]
                        txn.put(format_for_lmdb(video_name, 'length'), format_for_lmdb(len(result['img'])))

                        for frame_idx, frame in enumerate(result['img']):
                            txn.put(format_for_lmdb(video_name, frame_idx), frame)

                        if result['kp']:
                            txn.put(format_for_lmdb(video_name, 'keypoint'), result['kp'])
                        if result['coeff_3dmm']:
                            txn.put(format_for_lmdb(video_name, 'coeff_3dmm'), result['coeff_3dmm'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--path', type=str, help='path to the input video directory')
    parser.add_argument('--keypoint_path', type=str, help='path to the keypoint .mat files', default=None)
    parser.add_argument('--coeff_3dmm_path', type=str, help='path to the 3DMM coefficient .mat files', default=None)
    parser.add_argument('--out', type=str, help='path to the output directory')
    parser.add_argument('--sizes', type=int, nargs='+', default=(256,))
    parser.add_argument('--n_worker', type=int, help='number of worker processes', default=8)
    parser.add_argument('--chunksize', type=int, help='approximate chunksize for each worker', default=10)
    parser.add_argument('--img_format', type=str, default='jpeg')
    args = parser.parse_args()
    prepare_data(**vars(args))
60 changes: 60 additions & 0 deletions third_part/PerceptualSimilarity/models/base_model.py
@@ -0,0 +1,60 @@
import os
import numpy as np
import torch


class BaseModel():
    def __init__(self):
        pass

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True):
        self.use_gpu = use_gpu
        self.Tensor = torch.cuda.FloatTensor if self.use_gpu else torch.Tensor
        # self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def forward(self):
        pass

    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s' % save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        pass

    def save_done(self, flag=False):
        # np.save appends '.npy'; np.savetxt writes the plain-text flag file.
        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
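
For context, a minimal sketch of how a subclass might use the save/load helpers above; `ToyModel` and the paths are hypothetical, purely illustrative:

```python
import os
import torch.nn as nn

# Hypothetical subclass, for illustration only.
class ToyModel(BaseModel):
    def initialize(self, use_gpu=False):
        super().initialize(use_gpu)
        self.save_dir = '/tmp/toy'  # load_network reads self.save_dir
        self.net = nn.Linear(4, 4)

model = ToyModel()
model.initialize()
os.makedirs(model.save_dir, exist_ok=True)
model.save_network(model.net, model.save_dir, 'toy', 'latest')  # writes latest_net_toy.pth
model.load_network(model.net, 'toy', 'latest')                  # reloads the weights
```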
