forked from soubhiksanyal/RingNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
131 lines (115 loc) · 4.7 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Author: Soubhik Sanyal
Copyright (c) 2019, Soubhik Sanyal
All rights reserved.
based on github.com/akanazawa/hmr
"""
## Demo of RingNet.
## Note that RingNet requires a loose crop of the face in the image.
## Sample usage:
## Run the following command to check the RingNet predictions on a loosely cropped face image
# python -m demo --img_path *.jpg --out_folder ./RingNet_output
## To output the meshes run the following command
# python -m demo --img_path *.jpg --out_folder ./RingNet_output --save_obj_file=True
## To output both meshes and flame parameters run the following command
# python -m demo --img_path *.jpg --out_folder ./RingNet_output --save_obj_file=True --save_flame_parameters=True
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
from absl import flags
import numpy as np
import skimage.io as io
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from psbody.mesh import Mesh
from smpl_webuser.serialization import load_model
from util import renderer as vis_util
from util import image as img_util
from config_test import get_config
from run_RingNet import RingNet_inference
def visualize(img, proc_param, verts, cam, img_name='test_image'):
    """
    Render the predicted mesh in the original image coordinate frame and
    save a 2x2 comparison figure to ``img_name + '.png'``.

    The four panels are: the input image, the mesh rendered over the input,
    the mesh rendered alone, and the mesh re-rendered via
    ``renderer.rotated(..., 30, ...)`` (presumably a 30-degree rotation --
    confirm against util.renderer).

    Relies on the module-level ``renderer`` created in the ``__main__`` block.

    Args:
        img: the original (uncropped) image returned by ``preprocess_image``.
        proc_param: crop/scale parameters returned by ``preprocess_image``.
        verts: predicted mesh vertices for a single image.
        cam: predicted camera parameters (the first 3 FLAME outputs).
        img_name: output path without the ``.png`` extension.
    """
    img_size = img.shape[:2]
    # Undo the crop/scale so the mesh lines up with the original image.
    cam_for_render, vert_shifted = vis_util.get_original(
        proc_param, verts, cam, img_size=img_size)

    # Render results
    rend_img_overlay = renderer(
        vert_shifted * 1.0, cam=cam_for_render, img=img, do_alpha=True)
    rend_img = renderer(
        vert_shifted * 1.0, cam=cam_for_render, img_size=img_size)
    rend_img_vp1 = renderer.rotated(
        vert_shifted, 30, cam=cam_for_render, img_size=img_size)

    panels = [
        (img, 'input'),
        (rend_img_overlay, '3D Mesh overlay'),
        (rend_img, '3D mesh'),
        (rend_img_vp1, 'diff vp'),
    ]
    # Reuse figure 1 and clear it so repeated calls do not pile up figures.
    fig = plt.figure(1)
    plt.clf()
    for index, (panel, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, index)
        plt.imshow(panel)
        plt.title(title)
        plt.axis('off')
    plt.draw()
    plt.show(block=False)
    fig.savefig(img_name + '.png')
def preprocess_image(img_path, img_size=None):
    """
    Load an image, scale it so its longer side equals ``img_size``, and crop
    around the image centre with ``img_util.scale_and_crop``.

    RingNet expects a loose crop of the face, so the face should already
    fill most of the input image.

    Args:
        img_path: path of the image file to read.
        img_size: target size in pixels. Defaults to the module-level
            ``config.img_size`` (``config`` is set in the ``__main__`` block).

    Returns:
        (crop, proc_param, img):
        - ``crop``: the cropped image normalised from [0, 255] to [-1, 1].
        - ``proc_param``: the parameters returned by ``img_util.scale_and_crop``.
        - ``img``: the loaded image, converted to 3 channels.
    """
    if img_size is None:
        img_size = config.img_size
    img = io.imread(img_path)
    # Grayscale files load as 2-D arrays and PNGs may carry an alpha channel.
    # The network input and the renderer overlay presumably expect 3-channel
    # RGB; the upstream hmr demo drops alpha the same way.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]

    max_side = np.max(img.shape[:2])
    if max_side != img_size:
        print('Resizing so the max image size is %d..' % img_size)
        scale = float(img_size) / max_side
    else:
        scale = 1.0

    # Image centre, flipped from (row, col) to (x, y).
    center = np.round(np.array(img.shape[:2]) / 2).astype(int)[::-1]
    crop, proc_param = img_util.scale_and_crop(img, scale, center, img_size)

    # Normalize image to [-1, 1]
    crop = 2 * ((crop / 255.) - 0.5)
    return crop, proc_param, img
def _output_stem(img_path):
    """Return the file name of ``img_path`` with directory and extension removed."""
    return os.path.splitext(os.path.basename(img_path))[0]


def _ensure_dir(path):
    """Create directory ``path`` (including missing parents) unless it exists."""
    if not os.path.exists(path):
        os.makedirs(path)


def _write_outputs(config, img, proc_param, verts, params):
    """Save the visualisation and, if requested, the mesh and FLAME parameters.

    ``verts`` and ``params`` are the predictions for a single image.
    """
    stem = _output_stem(config.img_path)
    out_folder = config.out_folder

    images_dir = out_folder + '/images'
    _ensure_dir(images_dir)
    # The first 3 entries of the parameter vector are the camera.
    visualize(img, proc_param, verts, params[:3], img_name=images_dir + '/' + stem)

    if config.save_obj_file:
        mesh_dir = out_folder + '/mesh'
        _ensure_dir(mesh_dir)
        # Predicted vertices share the FLAME template's face topology.
        template_mesh = load_model(config.flame_model_path)
        mesh = Mesh(v=verts, f=template_mesh.f)
        mesh.write_obj(mesh_dir + '/' + stem + '.obj')

    if config.save_flame_parameters:
        params_dir = out_folder + '/params'
        _ensure_dir(params_dir)
        # Parameter vector layout as sliced here:
        # [cam (3) | pose (pose_params) | shape (shape_params) | expression (rest)]
        pose_end = 3 + config.pose_params
        shape_end = pose_end + config.shape_params
        flame_parameters_ = {
            'cam': params[:3],
            'pose': params[3:pose_end],
            'shape': params[pose_end:shape_end],
            'expression': params[shape_end:],
        }
        # np.save stores the dict as a pickled 0-d object array; read it back
        # with np.load(path, allow_pickle=True).item().
        np.save(params_dir + '/' + stem + '.npy', flame_parameters_)


def main(config):
    """
    Run RingNet on ``config.img_path`` and write the results under
    ``config.out_folder``.

    Outputs, where ``<stem>`` is the input file name without directory or
    extension:
    - ``<out_folder>/images/<stem>.png``: always written, via ``visualize``.
    - ``<out_folder>/mesh/<stem>.obj``: written when ``config.save_obj_file`` is set.
    - ``<out_folder>/params/<stem>.npy``: written when
      ``config.save_flame_parameters`` is set.
    """
    sess = tf.Session()
    try:
        model = RingNet_inference(config, sess=sess)
        input_img, proc_param, img = preprocess_image(config.img_path)
        # The model consumes a batch; run it on a batch of one image.
        vertices, flame_parameters = model.predict(
            np.expand_dims(input_img, axis=0), get_parameters=True)
        _write_outputs(config, img, proc_param, vertices[0], flame_parameters[0])
    finally:
        # Release the session's resources even if inference or saving fails.
        sess.close()
if __name__ == '__main__':
    # Script entry point. ``config`` and ``renderer`` must stay module-level
    # names: preprocess_image() reads ``config`` and visualize() reads
    # ``renderer`` as globals.
    config = get_config()
    template_mesh = load_model(config.flame_model_path)
    renderer = vis_util.SMPLRenderer(faces=template_mesh.f)

    # Create the output root (with parents), then its images/ sub-folder.
    required_dirs = (
        (config.out_folder, os.makedirs),
        (config.out_folder + '/images', os.mkdir),
    )
    for directory, create in required_dirs:
        if not os.path.exists(directory):
            create(directory)

    main(config)