@@ -226,7 +226,7 @@ def forward(self, x, e_outputs, src_mask=None, trg_mask=None):
226
226
return self .norm (x )
227
227
228
228
229
- class Audio2landmark_speaker_aware (nn .Module ):
229
+ class Audio2landmark_speaker_aware_old (nn .Module ):
230
230
231
231
def __init__ (self , spk_emb_enc_size = 128 ,
232
232
transformer_d_model = 32 , N = 2 , heads = 2 ,
@@ -291,7 +291,7 @@ def __init__(self, spk_emb_enc_size=128,
291
291
)
292
292
293
293
294
- def forward (self , au , face_id , add_z_spk = False ):
294
+ def forward (self , au , face_id ):
295
295
296
296
''' original version '''
297
297
# audio
@@ -321,6 +321,87 @@ def forward(self, au, face_id, add_z_spk=False):
321
321
return fl_pred , pos_pred , face_id [0 :1 , :], None
322
322
323
323
324
class Audio2landmark_speaker_aware(nn.Module):
    """Audio-to-landmark predictor conditioned on a speaker embedding.

    An LSTM encodes an audio feature window; an MLP encodes the speaker
    embedding; the two codes plus a zero pad (the ``z_size`` slot) are
    concatenated and passed through a transformer ``Encoder``, and a final
    MLP regresses a 204-dim landmark vector (presumably 68 x 3 coordinates
    — TODO confirm against the caller).
    """

    def __init__(self, audio_feat_size=80, c_enc_hidden_size=256, num_layers=3, drop_out=0,
                 spk_feat_size=256, spk_emb_enc_size=128, lstm_g_win_size=64, add_info_size=6,
                 transformer_d_model=32, N=2, heads=2, z_size=128, audio_dim=256):
        """Build the sub-modules.

        Args:
            audio_feat_size: per-frame audio feature width fed to the LSTM.
            c_enc_hidden_size: LSTM hidden size (audio content code width).
            num_layers: number of stacked LSTM layers.
            drop_out: LSTM inter-layer dropout probability.
            spk_feat_size: raw speaker-embedding width.
            spk_emb_enc_size: encoded speaker-embedding width.
            lstm_g_win_size, add_info_size: stored but unused here
                (kept for interface compatibility with callers).
            transformer_d_model, N, heads: transformer Encoder hyperparams;
                the effective model width is ``transformer_d_model * heads``.
            z_size: width of the zero pad appended to encoder input/output.
            audio_dim: desired audio-code width; a projection MLP is added
                only when it differs from ``c_enc_hidden_size``.
        """
        super(Audio2landmark_speaker_aware, self).__init__()

        self.lstm_g_win_size = lstm_g_win_size
        self.add_info_size = add_info_size
        # Remember the pad width so forward() does not hard-code 128.
        self.z_size = z_size

        # Audio content encoder: unidirectional, batch-first LSTM.
        self.audio_content_encoder = nn.LSTM(input_size=audio_feat_size,
                                             hidden_size=c_enc_hidden_size,
                                             num_layers=num_layers,
                                             dropout=drop_out,
                                             bidirectional=False,
                                             batch_first=True)

        # Project the LSTM output only when its width differs from audio_dim.
        self.use_audio_projection = (audio_dim != c_enc_hidden_size)
        if self.use_audio_projection:
            self.audio_projection = nn.Sequential(
                nn.Linear(in_features=c_enc_hidden_size, out_features=256),
                nn.LeakyReLU(0.02),
                nn.Linear(256, 128),
                nn.LeakyReLU(0.02),
                nn.Linear(128, audio_dim),
            )

        ''' original version '''
        # Speaker-embedding encoder MLP.
        self.spk_emb_encoder = nn.Sequential(
            nn.Linear(in_features=spk_feat_size, out_features=256),
            nn.LeakyReLU(0.02),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.02),
            nn.Linear(128, spk_emb_enc_size),
        )

        d_model = transformer_d_model * heads

        self.encoder = Encoder(d_model, N, heads, in_size=audio_dim + spk_emb_enc_size + z_size)
        # NOTE(review): self.decoder is never used in forward(); kept so
        # existing checkpoints / state_dicts still load — confirm before removing.
        self.decoder = Decoder(d_model, N, heads, in_size=204)
        self.out = nn.Sequential(
            nn.Linear(in_features=d_model + z_size, out_features=512),
            nn.LeakyReLU(0.02),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.02),
            nn.Linear(256, 204),
        )

    def forward(self, au, emb, face_id, add_z_spk=False, another_emb=None):
        """Predict landmarks from audio ``au`` and speaker embedding ``emb``.

        Args:
            au: audio feature batch, shape (batch, time, audio_feat_size)
                per the batch_first LSTM.
            emb: speaker embedding batch, last dim ``spk_feat_size``.
            face_id: passed through; only its first row is returned.
            add_z_spk: if True, perturb the speaker code with small noise.
            another_emb: unused; kept for caller compatibility.

        Returns:
            (fl_pred, face_id[0:1, :], spk_encode)
        """
        # Audio: keep only the final timestep's hidden activation.
        audio_encode, (_, _) = self.audio_content_encoder(au)
        audio_encode = audio_encode[:, -1, :]

        if self.use_audio_projection:
            audio_encode = self.audio_projection(audio_encode)

        # Speaker code.
        spk_encode = self.spk_emb_encoder(emb)
        if add_z_spk:
            # Small Gaussian perturbation, detached so it carries no gradient.
            # (randn_like matches spk_encode's device/dtype; avoids the
            # torch.tensor(tensor) copy-construct anti-pattern.)
            z_spk = (torch.randn_like(spk_encode) * 0.01).detach()
            spk_encode = spk_encode + z_spk

        # Combine audio + speaker + zero pad of width z_size.
        z = torch.zeros(au.shape[0], self.z_size, dtype=torch.float, device=au.device)
        comb_encode = torch.cat((audio_encode, spk_encode, z), dim=1)
        src_feat = comb_encode.unsqueeze(0)

        e_outputs = self.encoder(src_feat)[0]

        # Re-append the pad before the output head (head expects d_model + z_size).
        e_outputs = torch.cat((e_outputs, z), dim=1)

        fl_pred = self.out(e_outputs)

        return fl_pred, face_id[0:1, :], spk_encode
324
405
325
406
326
407
def nopeak_mask (size ):
@@ -344,23 +425,6 @@ def create_masks(src, trg):
344
425
return src_mask , trg_mask
345
426
346
427
347
class TalkingToon_spk2res_lstmgan_DL(nn.Module):
    """Discriminator MLP over landmark features.

    Maps a feature vector of width ``FACE_ID_FEAT_SIZE`` to a single raw
    score. No sigmoid is applied here — squashing is deliberately left to
    the loss (e.g. a logits-based BCE).
    """

    def __init__(self, comb_emb_size=256, input_size=6):
        super(TalkingToon_spk2res_lstmgan_DL, self).__init__()

        # 3-layer MLP: FACE_ID_FEAT_SIZE -> 512 -> 256 -> 1.
        layers = [
            nn.Linear(in_features=FACE_ID_FEAT_SIZE, out_features=512),
            nn.LeakyReLU(0.02),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.02),
            nn.Linear(256, 1),
        ]
        self.fl_D = nn.Sequential(*layers)

    def forward(self, feat):
        """Return the raw (unsquashed) discriminator score for ``feat``."""
        score = self.fl_D(feat)
        # d = torch.sigmoid(d)
        return score
364
428
365
429
class Transformer_DT (nn .Module ):
366
430
def __init__ (self , transformer_d_model = 32 , N = 2 , heads = 2 , spk_emb_enc_size = 128 ):
@@ -375,11 +439,11 @@ def __init__(self, transformer_d_model=32, N=2, heads=2, spk_emb_enc_size=128):
375
439
nn .Linear (256 , 1 ),
376
440
)
377
441
378
- def forward (self , fls , spk_emb , win_size = 64 , win_step = 1 ):
442
+ def forward (self , fls , spk_emb , win_size = 64 , win_step = 16 ):
379
443
feat = torch .cat ((fls , spk_emb ), dim = 1 )
380
444
381
445
win_size = feat .shape [0 ]- 1 if feat .shape [0 ] <= win_size else win_size
382
- D_input = [feat [i :i + win_size :win_step ] for i in range (0 , feat .shape [0 ]- win_size )]
446
+ D_input = [feat [i :i + win_size :win_step ] for i in range (0 , feat .shape [0 ]- win_size , win_step )]
383
447
D_input = torch .stack (D_input , dim = 0 )
384
448
D_output = self .encoder (D_input )
385
449
D_output = torch .max (D_output , dim = 1 , keepdim = False )[0 ]
0 commit comments