Commit d9a3393: "Update test code" (1 parent: 062683f)

5 files changed: +122 −40 lines

main_end2end.py (+4 −7)

```diff
@@ -35,18 +35,15 @@
 parser.add_argument('--jpg', type=str, default='{}.jpg'.format(default_head_name))
 parser.add_argument('--close_input_face_mouth', default=CLOSE_INPUT_FACE_MOUTH, action='store_true')
 
-
 parser.add_argument('--load_AUTOVC_name', type=str, default='examples/ckpt/ckpt_autovc.pth')
-parser.add_argument('--load_a2l_G_name', type=str, default='examples/ckpt/ckpt_a2l_db_e_875.pth')
-parser.add_argument('--load_a2l_C_name', type=str, default='examples/ckpt/ckpt_audio2landmark_c.pth')
+parser.add_argument('--load_a2l_G_name', type=str, default='examples/ckpt/ckpt_speaker_branch.pth')
+parser.add_argument('--load_a2l_C_name', type=str, default='examples/ckpt/ckpt_content_branch.pth') #ckpt_audio2landmark_c.pth')
 parser.add_argument('--load_G_name', type=str, default='examples/ckpt/ckpt_116_i2i_comb.pth') #ckpt_image2image.pth') #ckpt_i2i_finetune_150.pth') #c
 
 parser.add_argument('--amp_lip_x', type=float, default=2.)
 parser.add_argument('--amp_lip_y', type=float, default=2.)
-parser.add_argument('--amp_pos', type=float, default=1.)
+parser.add_argument('--amp_pos', type=float, default=.5)
 parser.add_argument('--reuse_train_emb_list', type=str, nargs='+', default=[]) # ['iWeklsXc0H8']) #['45hn7-LXDX8']) #['E_kmpT-EfOg']) #'iWeklsXc0H8', '29k8RtSUjE0', '45hn7-LXDX8',
-# --reuse_train_emb_list 45hn7-LXDX8
-
 parser.add_argument('--add_audio_in', default=False, action='store_true')
 parser.add_argument('--comb_fan_awing', default=False, action='store_true')
 parser.add_argument('--output_folder', type=str, default='examples')
@@ -84,7 +81,7 @@
 
 
 ''' Additional manual adjustment to input face landmarks (slimmer lips and wider eyes) '''
-shape_3d[48:, 0] = (shape_3d[48:, 0] - np.mean(shape_3d[48:, 0])) * 0.95 + np.mean(shape_3d[48:, 0])
+# shape_3d[48:, 0] = (shape_3d[48:, 0] - np.mean(shape_3d[48:, 0])) * 0.95 + np.mean(shape_3d[48:, 0])
 shape_3d[49:54, 1] += 1.
 shape_3d[55:60, 1] -= 1.
 shape_3d[[37,38,43,44], 1] -=2
```
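Note: this commit points the audio-to-landmark checkpoints at the renamed speaker/content branch files, halves the default head-pose amplification (`--amp_pos` 1.0 to 0.5), and disables the lip-slimming adjustment. The index 48 comes from the standard 68-point face-landmark convention, where the mouth occupies landmarks 48-67; flattened to a 204-dim vector (68 landmarks x 3 coords), mouth x and y are reached with strided slices, as the `amp_lip` code elsewhere in the repo does. A minimal sketch of that indexing (the array and values here are illustrative, not from the repo):

```python
import numpy as np

fl = np.zeros((10, 204))           # 10 frames, 68 landmarks x 3 coords, flattened
amp_lip_x, amp_lip_y = 2.0, 2.0    # defaults set by main_end2end.py
fl[:, 48 * 3::3] *= amp_lip_x      # every 3rd entry from 144 on: mouth x
fl[:, 48 * 3 + 1::3] *= amp_lip_y  # offset by 1: mouth y
```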

src/approaches/train_audio2landmark.py (+22 −8)

```diff
@@ -97,7 +97,7 @@ def __init__(self, opt_parser, jpg_shape=None):
             print(key)
         print('====================================')
 
-    def __train_face_and_pos__(self, fls, aus, embs, face_id, smooth_win=31, close_mouth_ratio=.66):
+    def __train_face_and_pos__(self, fls, aus, embs, face_id, smooth_win=31, close_mouth_ratio=.99):
 
         fls_without_traj = fls[:, 0, :].detach().clone().requires_grad_(False)
 
@@ -107,7 +107,7 @@ def __train_face_and_pos__(self, fls, aus, embs, face_id, smooth_win=31, close_m
         baseline_face_id = face_id.detach()
 
         z = torch.tensor(torch.zeros(aus.shape[0], 128), requires_grad=False, dtype=torch.float).to(device)
-        fl_dis_pred, _, spk_encode = self.G(aus, embs * 3.0, face_id, fls_without_traj, z, add_z_spk=True)
+        fl_dis_pred, _, spk_encode = self.G(aus, embs * 3.0, face_id, fls_without_traj, z, add_z_spk=False)
 
         # ADD CONTENT
         from scipy.signal import savgol_filter
@@ -133,16 +133,30 @@
 
         # ''' CALIBRATION '''
         baseline_pred_fls, _ = self.C(aus[:, 0:18, :], residual_face_id)
+        baseline_pred_fls = self.__calib_baseline_pred_fls__(baseline_pred_fls)
+        fl_dis_pred += baseline_pred_fls
+
+        return fl_dis_pred, face_id[0:1, :]
+
+    def __calib_baseline_pred_fls_old_(self, baseline_pred_fls, residual_face_id, aus):
         mean_face_id = torch.mean(baseline_pred_fls.detach(), dim=0, keepdim=True)
         residual_face_id -= mean_face_id.view(1, 204) * 1.
-        # ''' ======================== '''
-
-        baseline_pred_fls, _ = self.C(aus[:, 0:18, :], residual_face_id)
+        baseline_pred_fls, _ = self.C(aus, residual_face_id)
         baseline_pred_fls[:, 48 * 3::3] *= self.opt_parser.amp_lip_x # mouth x
         baseline_pred_fls[:, 48 * 3 + 1::3] *= self.opt_parser.amp_lip_y # mouth y
-        fl_dis_pred += baseline_pred_fls
-
-        return fl_dis_pred, face_id[0:1, :]
+        return baseline_pred_fls
+
+    def __calib_baseline_pred_fls__(self, baseline_pred_fls, ratio=0.5):
+        np_fl_dis_pred = baseline_pred_fls.detach().cpu().numpy()
+        K = int(np_fl_dis_pred.shape[0] * ratio)
+        for calib_i in range(204):
+            min_k_idx = np.argpartition(np_fl_dis_pred[:, calib_i], K)
+            m = np.mean(np_fl_dis_pred[min_k_idx[:K], calib_i])
+            np_fl_dis_pred[:, calib_i] = np_fl_dis_pred[:, calib_i] - m
+        baseline_pred_fls = torch.tensor(np_fl_dis_pred, requires_grad=False).to(device)
+        baseline_pred_fls[:, 48 * 3::3] *= self.opt_parser.amp_lip_x # mouth x
+        baseline_pred_fls[:, 48 * 3 + 1::3] *= self.opt_parser.amp_lip_y # mouth y
+        return baseline_pred_fls
 
     def __train_pass__(self, au_emb=None, centerize_face=False, no_y_rotation=False, vis_fls=False):
 
```
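Note: alongside smaller tweaks (`close_mouth_ratio` raised to .99, speaker-noise injection `add_z_spk` turned off), the inline calibration is moved out of `__train_face_and_pos__` (the previous version is kept as `__calib_baseline_pred_fls_old_`) and replaced by a lowest-K recentering: for each of the 204 displacement channels, `np.argpartition` finds the K smallest values (K = `ratio` x number of frames, default 50%) and their mean is subtracted, pinning each channel's resting low state near zero before the mouth amplification is applied. A self-contained sketch of that recentering on a plain numpy array (the function name and test data are illustrative):

```python
import numpy as np

def calib_lowest_k(pred, ratio=0.5):
    """Subtract, per channel, the mean of that channel's K smallest values."""
    pred = pred.copy()
    K = int(pred.shape[0] * ratio)
    for i in range(pred.shape[1]):
        min_k_idx = np.argpartition(pred[:, i], K)[:K]  # K smallest, unordered, O(n)
        pred[:, i] -= np.mean(pred[min_k_idx, i])
    return pred

calibrated = calib_lowest_k(np.random.randn(100, 204))  # (frames, channels)
```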

src/dataset/audio2landmark/audio2landmark_dataset.py (+2 −2)

```diff
@@ -79,7 +79,7 @@ def my_collate_in_segments(self, batch):
         return fls, aus, embs
 
     def my_collate_in_segments_noemb(self, batch):
-        fls, aus, embs = [], [], []
+        fls, aus = [], []
        for fl, au in batch:
            fl_data, au_data = fl[0], au[0]
            assert (fl_data.shape[0] == au_data.shape[0])
@@ -229,7 +229,7 @@ def __init__(self, dump_dir, dump_name, num_window_frames, num_window_step, stat
            # print('SAVE!')
 
 
-        au_mean_std = np.loadtxt('dataset/utils/MEAN_STD_AUTOVC_RETRAIN_MEL_AU.txt') # np.mean(self.au_data[0][0]), np.std(self.au_data[0][0])
+        au_mean_std = np.loadtxt('src/dataset/utils/MEAN_STD_AUTOVC_RETRAIN_MEL_AU.txt') # np.mean(self.au_data[0][0]), np.std(self.au_data[0][0])
         au_mean, au_std = au_mean_std[0:au_mean_std.shape[0]//2], au_mean_std[au_mean_std.shape[0]//2:]
 
         self.au_data = [((au - au_mean) / au_std, info) for au, info in self.au_data]
```
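Note: two small fixes here: `my_collate_in_segments_noemb` no longer allocates an `embs` list it cannot fill (the batch only yields `(fl, au)` pairs), and the normalization-stats path gains the `src/` prefix to match the repo layout this commit assumes. As the split on the following line implies, the stats file is a single 1-D array with per-dimension means in the first half and stds in the second. A minimal sketch of that layout (the fake feature array is illustrative):

```python
import numpy as np

au_mean_std = np.loadtxt('src/dataset/utils/MEAN_STD_AUTOVC_RETRAIN_MEL_AU.txt')
half = au_mean_std.shape[0] // 2
au_mean, au_std = au_mean_std[:half], au_mean_std[half:]

au = np.random.randn(100, half)      # fake (frames, feat) audio features
au_normed = (au - au_mean) / au_std  # per-dimension z-scoring via broadcasting
```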

src/models/model_audio2landmark_speaker_aware.py (+85 −21)

```diff
@@ -226,7 +226,7 @@ def forward(self, x, e_outputs, src_mask=None, trg_mask=None):
         return self.norm(x)
 
 
-class Audio2landmark_speaker_aware(nn.Module):
+class Audio2landmark_speaker_aware_old(nn.Module):
 
     def __init__(self, spk_emb_enc_size=128,
                  transformer_d_model=32, N=2, heads=2,
@@ -291,7 +291,7 @@ def __init__(self, spk_emb_enc_size=128,
         )
 
 
-    def forward(self, au, face_id, add_z_spk=False):
+    def forward(self, au, face_id):
 
         ''' original version '''
         # audio
@@ -321,6 +321,87 @@ def forward(self, au, face_id, add_z_spk=False):
         return fl_pred, pos_pred, face_id[0:1, :], None
 
 
+class Audio2landmark_speaker_aware(nn.Module):
+
+    def __init__(self, audio_feat_size=80, c_enc_hidden_size=256, num_layers=3, drop_out=0,
+                 spk_feat_size=256, spk_emb_enc_size=128, lstm_g_win_size=64, add_info_size=6,
+                 transformer_d_model=32, N=2, heads=2, z_size=128, audio_dim=256):
+        super(Audio2landmark_speaker_aware, self).__init__()
+
+        self.lstm_g_win_size = lstm_g_win_size
+        self.add_info_size = add_info_size
+        comb_mlp_size = c_enc_hidden_size * 2
+
+        self.audio_content_encoder = nn.LSTM(input_size=audio_feat_size,
+                                             hidden_size=c_enc_hidden_size,
+                                             num_layers=num_layers,
+                                             dropout=drop_out,
+                                             bidirectional=False,
+                                             batch_first=True)
+
+        self.use_audio_projection = not (audio_dim == c_enc_hidden_size)
+        if(self.use_audio_projection):
+            self.audio_projection = nn.Sequential(
+                nn.Linear(in_features=c_enc_hidden_size, out_features=256),
+                nn.LeakyReLU(0.02),
+                nn.Linear(256, 128),
+                nn.LeakyReLU(0.02),
+                nn.Linear(128, audio_dim),
+            )
+
+
+        ''' original version '''
+        self.spk_emb_encoder = nn.Sequential(
+            nn.Linear(in_features=spk_feat_size, out_features=256),
+            nn.LeakyReLU(0.02),
+            nn.Linear(256, 128),
+            nn.LeakyReLU(0.02),
+            nn.Linear(128, spk_emb_enc_size),
+        )
+
+        d_model = transformer_d_model * heads
+        N = N
+        heads = heads
+
+        self.encoder = Encoder(d_model, N, heads, in_size=audio_dim + spk_emb_enc_size + z_size)
+        self.decoder = Decoder(d_model, N, heads, in_size=204)
+        self.out = nn.Sequential(
+            nn.Linear(in_features=d_model + z_size, out_features=512),
+            nn.LeakyReLU(0.02),
+            nn.Linear(512, 256),
+            nn.LeakyReLU(0.02),
+            nn.Linear(256, 204),
+        )
+
+
+    def forward(self, au, emb, face_id, add_z_spk=False, another_emb=None):
+
+        # audio
+        audio_encode, (_, _) = self.audio_content_encoder(au)
+        audio_encode = audio_encode[:, -1, :]
+
+        if(self.use_audio_projection):
+            audio_encode = self.audio_projection(audio_encode)
+
+        # spk
+        spk_encode = self.spk_emb_encoder(emb)
+        if(add_z_spk):
+            z_spk = torch.tensor(torch.randn(spk_encode.shape)*0.01, requires_grad=False, dtype=torch.float).to(device)
+            spk_encode = spk_encode + z_spk
+
+        # comb
+        z = torch.tensor(torch.zeros(au.shape[0], 128), requires_grad=False, dtype=torch.float).to(device)
+        comb_encode = torch.cat((audio_encode, spk_encode, z), dim=1)
+        src_feat = comb_encode.unsqueeze(0)
+
+        e_outputs = self.encoder(src_feat)[0]
+
+        e_outputs = torch.cat((e_outputs, z), dim=1)
+
+        fl_pred = self.out(e_outputs)
+
+        return fl_pred, face_id[0:1, :], spk_encode
+
 
 
 def nopeak_mask(size):
@@ -344,23 +425,6 @@ def create_masks(src, trg):
     return src_mask, trg_mask
 
 
-class TalkingToon_spk2res_lstmgan_DL(nn.Module):
-    def __init__(self, comb_emb_size=256, input_size=6):
-        super(TalkingToon_spk2res_lstmgan_DL, self).__init__()
-
-        self.fl_D = nn.Sequential(
-            nn.Linear(in_features=FACE_ID_FEAT_SIZE, out_features=512),
-            nn.LeakyReLU(0.02),
-            nn.Linear(512, 256),
-            nn.LeakyReLU(0.02),
-            nn.Linear(256, 1),
-        )
-
-    def forward(self, feat):
-        d = self.fl_D(feat)
-        # d = torch.sigmoid(d)
-        return d
-
 
 class Transformer_DT(nn.Module):
     def __init__(self, transformer_d_model=32, N=2, heads=2, spk_emb_enc_size=128):
@@ -375,11 +439,11 @@ def __init__(self, transformer_d_model=32, N=2, heads=2, spk_emb_enc_size=128):
             nn.Linear(256, 1),
         )
 
-    def forward(self, fls, spk_emb, win_size=64, win_step=1):
+    def forward(self, fls, spk_emb, win_size=64, win_step=16):
        feat = torch.cat((fls, spk_emb), dim=1)
 
        win_size = feat.shape[0]-1 if feat.shape[0] <= win_size else win_size
-        D_input = [feat[i:i+win_size:win_step] for i in range(0, feat.shape[0]-win_size)]
+        D_input = [feat[i:i+win_size:win_step] for i in range(0, feat.shape[0]-win_size, win_step)]
        D_input = torch.stack(D_input, dim=0)
        D_output = self.encoder(D_input)
        D_output = torch.max(D_output, dim=1, keepdim=False)[0]
```
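Note: the transformer-only generator is retired as `Audio2landmark_speaker_aware_old`, and the new `Audio2landmark_speaker_aware` encodes audio with an LSTM (plus an optional projection to `audio_dim`), encodes the speaker embedding with an MLP, and feeds their concatenation (padded with a zero `z` vector) through the transformer `Encoder`; the unused discriminator `TalkingToon_spk2res_lstmgan_DL` is deleted. The `Transformer_DT.forward` change also strides the window *starts* by `win_step` instead of 1, so the discriminator sees far fewer, less redundant windows. A toy sketch of that windowing difference (shapes are illustrative; 332 = 204 landmark dims + 128 speaker-embedding dims):

```python
import torch

feat = torch.randn(200, 332)  # (frames, landmark + spk-emb features)
win_size, win_step = 64, 16

# before: a window starting at every frame; after: starts also stride by win_step
old = [feat[i:i + win_size:win_step] for i in range(0, feat.shape[0] - win_size)]
new = [feat[i:i + win_size:win_step] for i in range(0, feat.shape[0] - win_size, win_step)]
print(len(old), len(new))  # 136 vs 9 windows, each of shape (4, 332)
```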

util/vis.py (+9 −2)

```diff
@@ -102,7 +102,8 @@ def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth):
 
 class Vis_old():
 
-    def __init__(self, run_name, pred_fl_filename, audio_filename, av_name='NAME', fps=100, frames=625, postfix='', root_dir=r'E:\Dataset\TalkingToon\Obama', ifsmooth=True):
+    def __init__(self, run_name, pred_fl_filename, audio_filename, av_name='NAME', fps=100, frames=625,
+                 postfix='', root_dir=r'E:\Dataset\TalkingToon\Obama', ifsmooth=True, rand_start=0):
 
         print(root_dir)
         self.src_dir = os.path.join(root_dir, r'nn_result/{}'.format(run_name))
@@ -140,13 +141,19 @@ def __init__(self, run_name, pred_fl_filename, audio_filename, av_name='NAME', f
         # out = out.overwrite_output().global_args('-loglevel', 'quiet')
         # out.run()
 
+        os.system('ffmpeg -y -loglevel error -i {} -ss {} {}'.format(
+            ain, rand_start/62.5,
+            os.path.join(self.src_dir, '{}_a_tmp.wav'.format(av_name))
+        ))
+
         os.system('ffmpeg -y -loglevel error -i {} -i {} -pix_fmt yuv420p -strict -2 -shortest {}'.format(
             os.path.join(self.src_dir, 'tmp.mp4'),
-            ain,
+            os.path.join(self.src_dir, '{}_a_tmp.wav'.format(av_name)),
             os.path.join(self.src_dir, '{}_av.mp4'.format(av_name))
         ))
 
         os.remove(os.path.join(self.src_dir, 'tmp.mp4'))
+        os.remove(os.path.join(self.src_dir, '{}_a_tmp.wav'.format(av_name)))
 
         # os.remove(os.path.join(self.src_dir, filename))
         # exit(0)
```
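Note: muxing now goes through a temporary trimmed audio file: a first ffmpeg pass seeks `rand_start / 62.5` seconds into the input audio (62.5 appears to be the landmark frame rate, converting a frame offset into seconds; the constructor's `fps=100` is a separate video setting), the second pass muxes the trimmed wav with the silent video, and the temp file is removed. A condensed sketch of the flow with illustrative filenames:

```python
import os

rand_start = 125              # hypothetical frame offset into the clip
seek_sec = rand_start / 62.5  # frames -> seconds at an assumed 62.5 fps => 2.0
os.system('ffmpeg -y -loglevel error -i in.wav -ss {} a_tmp.wav'.format(seek_sec))
os.system('ffmpeg -y -loglevel error -i tmp.mp4 -i a_tmp.wav '
          '-pix_fmt yuv420p -strict -2 -shortest out_av.mp4')
os.remove('a_tmp.wav')        # clean up the trimmed audio
```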
