# train_LawDNet_clip.py
import sys
import cv2
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import os
import time
import torch.nn.functional as F
from tqdm import tqdm
import wandb
import yaml
import argparse
from torch.cuda.amp import autocast
from models.Discriminator import Discriminator
from models.VGG19 import Vgg19
from models.LawDNet import LawDNet
from models.Syncnet import SyncNetPerception
from utils.training_utils import get_scheduler, update_learning_rate, GANLoss
from config.config import DINetTrainingOptions
from sync_batchnorm import convert_model
from torch.utils.data import DataLoader
from dataset.dataset_DINet_clip import DINetDataset
from models.Gaussian_blur import Gaussian_bluring
from tensor_processing import SmoothSqMask
from models.content_model import AudioContentModel, LipContentModel
from torch.nn.utils import clip_grad_norm_
# Freeze BatchNorm layers (keep running statistics fixed during training)
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
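# Note: fix_bn is not called in this script; a typical (hypothetical) usage
# would be `net_g.apply(fix_bn)` after loading pretrained weights, to keep
# BatchNorm statistics frozen while fine-tuning.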
def replace_images(fake_out, source_clip):
    '''
    Replace a random subset (0-4) of the first five frames of fake_out
    with the corresponding frames from source_clip.
    Returns a modified clone; the input tensors are left untouched.
    '''
    # Clone so the originals are not modified in place
    fake_out_clone = fake_out.clone()
    source_clip_clone = source_clip.clone()
    # Randomly choose 0-4 frame indices out of the first 5
    num_replace = random.randint(0, 4)
    indices = random.sample(range(5), num_replace)
    # Copy the selected source frames into the fake output
    for idx in indices:
        fake_out_clone[idx, :, :, :] = source_clip_clone[idx, :, :, :]
    return fake_out_clone
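# Note: replace_images is defined but never invoked in this script. Because it
# indexes dim 0, it assumes clips flattened to [5*B, C, H, W] (as done in
# train() below) and swaps whole frames, e.g. as a discriminator augmentation.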
# Initialize and log in to Weights & Biases
def init_wandb(name):
    wandb.login()
    wandb.init(project=name)
def load_experiment_config(config_module_path):
    """Dynamically load the experiment configuration from a Python file."""
    import importlib.util
    spec = importlib.util.spec_from_file_location("config_module", config_module_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.experiment_config
# Load the configuration and set up the device
def load_config_and_device(args):
    '''Load the experiment configuration and select the device.'''
    # Dynamically load the configuration file
    experiment_config = load_experiment_config(args.config_path)
    # Create the opt instance (DINetTrainingOptions handles its own parsing)
    opt = DINetTrainingOptions().parse_args()
    # Override opt with the dynamically loaded configuration
    for key, value in experiment_config.items():
        if hasattr(opt, key):
            setattr(opt, key, value)
    # wandb is assumed to be initialized already
    wandb.config.update(opt)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Insert the experiment name into the result path: split result_path at the
    # last '/' and place args.name before the final component
    path_parts = opt.result_path.rsplit('/', 1)
    opt.result_path = f'{path_parts[0]}/{args.name}/{path_parts[1]}'
    return opt, device
# Save configuration to a YAML file
def save_config_to_yaml(config, filename):
with open(filename, 'w') as file:
yaml.dump(config, file, default_flow_style=False)
# Build the training DataLoader
def load_training_data(opt):
train_data = DINetDataset(opt.train_data, opt.augment_num, opt.mouth_region_size)
training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True, drop_last=True)
return training_data_loader
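# Note: drop_last=True keeps every batch at exactly opt.batch_size, which the
# dirty-data check in train() relies on (it compares the flag tensor against
# torch.ones(opt.batch_size, 1)).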
# Initialize the networks (generator, discriminators, VGG, lip-sync perception).
# The device is passed in explicitly rather than read from a module-level global.
def init_networks(opt, device):
net_g = LawDNet(opt.source_channel, opt.ref_channel, opt.audio_channel,
opt.warp_layer_num, opt.num_kpoints, opt.coarse_grid_size).cuda()
net_dI = Discriminator(opt.source_channel, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda()
net_dV = Discriminator(opt.source_channel * 5, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda()
net_vgg = Vgg19().cuda()
net_lipsync = SyncNetPerception(opt.pretrained_syncnet_path).cuda()
device_ids = [int(x) for x in opt.cuda_devices.split(',')]
net_g = nn.DataParallel(net_g, device_ids=device_ids).to(device)
net_dI = nn.DataParallel(net_dI, device_ids=device_ids).to(device)
net_dV = nn.DataParallel(net_dV, device_ids=device_ids).to(device)
net_vgg = nn.DataParallel(net_vgg, device_ids=device_ids).to(device)
net_lipsync = nn.DataParallel(net_lipsync, device_ids=device_ids).to(device)
net_g = convert_model(net_g)
return net_g, net_dI, net_dV, net_vgg, net_lipsync
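# Note: convert_model comes from the bundled sync_batchnorm package; as used
# here it replaces the BatchNorm layers inside the DataParallel-wrapped
# generator with synchronized BatchNorm so statistics are shared across GPUs.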
# Set up the optimizers (opt is passed in rather than read from a global)
def setup_optimizers(net_g, net_dI, net_dV, opt):
    optimizer_g = optim.AdamW(net_g.parameters(), lr=opt.lr_g)
    optimizer_dI = optim.AdamW(net_dI.parameters(), lr=opt.lr_dI)
    # Note: the video discriminator reuses the image-discriminator learning rate
    optimizer_dV = optim.AdamW(net_dV.parameters(), lr=opt.lr_dI)
    return optimizer_g, optimizer_dI, optimizer_dV
def load_pretrained_weights(net_g, opt):
"""
Loads the pretrained weights into the model if a valid path is provided.
Parameters:
- net_g: the model into which the weights will be loaded.
- opt: options object that contains the path to the pretrained weights.
Returns:
- A boolean value indicating whether the weights were loaded successfully.
"""
if opt.pretrained_frame_DINet_path:
try:
print(f'Loading frame trained DINet weight from: {opt.pretrained_frame_DINet_path}')
checkpoint = torch.load(opt.pretrained_frame_DINet_path)
net_g.load_state_dict(checkpoint['state_dict']['net_g'])
print('Loading frame trained DINet weight finished!')
return True
except Exception as e:
print(f'Error loading pretrained weights: {e}')
return False
else:
print("Path to pretrained frame trained DINet weight is empty.")
return False
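# Note: the checkpoint is expected to follow the layout written by
# save_checkpoint() below, i.e. generator weights stored under
# checkpoint['state_dict']['net_g'].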
# Set up the loss functions
def setup_criterion():
criterionGAN = GANLoss().cuda()
criterionL1 = nn.L1Loss().cuda()
criterionMSE = nn.MSELoss().cuda()
criterionCosine = nn.CosineEmbeddingLoss().cuda()
return criterionGAN, criterionL1, criterionMSE, criterionCosine
# Set up the learning-rate schedulers (opt is passed in rather than read from a global)
def setup_schedulers(optimizer_g, optimizer_dI, optimizer_dV, opt):
net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay)
net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, opt.decay)
net_dV_scheduler = get_scheduler(optimizer_dV, opt.non_decay, opt.decay)
return net_g_scheduler, net_dI_scheduler, net_dV_scheduler
def log_to_wandb(source_clip, fake_out):
    source_clip = source_clip.float()  # cast back to full precision for logging
    fake_out = fake_out.float()
    # Log the ground-truth source frames
    images_source = [wandb.Image(source_clip[i].cpu(), caption=f"Source Clip {i}") for i in range(source_clip.shape[0])]
    wandb.log({"Source Clips": images_source})
    # Log the generated frames
    images_fake_out = [wandb.Image(fake_out[i].cpu(), caption=f"Fake Out {i}") for i in range(fake_out.shape[0])]
    wandb.log({"Fake Outs": images_fake_out})
# Training loop. net_vgg and net_lipsync are taken as parameters rather than
# module-level globals so the function is self-contained.
def train(
    opt,
    net_g,
    net_dI,
    net_dV,
    net_vgg,
    net_lipsync,
    training_data_loader,
    optimizer_g,
    optimizer_dI,
    optimizer_dV,
    criterionGAN,
    criterionL1,
    criterionMSE,
    criterionCosine,
    net_g_scheduler,
    net_dI_scheduler,
    net_dV_scheduler
):
    # Mixed-precision training: create a single GradScaler at the start of training
    scaler = torch.cuda.amp.GradScaler(enabled=True)
    smooth_sqmask = SmoothSqMask().cuda()
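    # AMP pattern used below: forward passes run under autocast; each loss is
    # scaled before backward, each optimizer steps through the shared scaler,
    # and scaler.update() is called once per iteration after all three steps.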
for epoch in range(opt.start_epoch, opt.non_decay + opt.decay + 1):
for iteration, data in enumerate(tqdm(training_data_loader, desc=f"Epoch {epoch}")):
source_clip, reference_clip, deep_speech_clip, deep_speech_full, flag = data
            # Skip batches that contain corrupted samples (flag != 1)
            flag = flag.cuda()
            if not (flag.equal(torch.ones(opt.batch_size, 1, device='cuda'))):
                print("Skipping batch containing corrupted data")
                continue
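            # Flatten clips for per-frame processing: the loader yields
            # [B, 5, C, H, W]; splitting along dim 1 and concatenating along
            # dim 0 gives [5*B, C, H, W], so the networks see single frames.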
source_clip = torch.cat(torch.split(source_clip, 1, dim=1), 0).squeeze(1).float().cuda()
reference_clip = torch.cat(torch.split(reference_clip, 1, dim=1), 0).squeeze(1).float().cuda()
deep_speech_clip = torch.cat(torch.split(deep_speech_clip, 1, dim=1), 0).squeeze(1).float().cuda()
deep_speech_full = deep_speech_full.float().cuda()
            # Apply a smooth square mask over the mouth region of the source frames
            source_clip_mask = smooth_sqmask(source_clip)
with autocast(enabled=True):
fake_out = net_g(source_clip_mask, reference_clip, deep_speech_clip)
fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False)
source_clip_half = F.interpolate(source_clip, scale_factor=0.5, mode='bilinear')
            # Update the image discriminator (DI)
optimizer_dI.zero_grad()
with autocast(enabled=True):
_, pred_fake_dI = net_dI(fake_out.detach())
loss_dI_fake = criterionGAN(pred_fake_dI, False)
_, pred_real_dI = net_dI(source_clip)
loss_dI_real = criterionGAN(pred_real_dI, True)
loss_dI = (loss_dI_fake + loss_dI_real) * 0.5
scaler.scale(loss_dI).backward(retain_graph=True)
scaler.step(optimizer_dI)
            # Update the video discriminator (DV) on 5-frame clips stacked along channels
optimizer_dV.zero_grad()
with autocast(enabled=True):
condition_fake_dV = torch.cat(torch.split(fake_out.detach(), opt.batch_size, dim=0), 1)
_, pred_fake_dV = net_dV(condition_fake_dV)
loss_dV_fake = criterionGAN(pred_fake_dV, False)
condition_real_dV = torch.cat(torch.split(source_clip, opt.batch_size, dim=0), 1)
_, pred_real_dV = net_dV(condition_real_dV)
loss_dV_real = criterionGAN(pred_real_dV, True)
loss_dV = (loss_dV_fake + loss_dV_real) * 0.5
scaler.scale(loss_dV).backward(retain_graph=True)
scaler.step(optimizer_dV)
            # Update the generator
            optimizer_g.zero_grad()
            with autocast(enabled=True):
                _, pred_fake_dI = net_dI(fake_out)
                # Rebuild the 5-frame condition from the non-detached fake_out so the
                # adversarial gradient from net_dV reaches the generator (the version
                # built above for the DV update came from fake_out.detach())
                condition_fake_dV = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1)
                _, pred_fake_dV = net_dV(condition_fake_dV)
perception_real = net_vgg(source_clip)
perception_fake = net_vgg(fake_out)
perception_real_half = net_vgg(source_clip_half)
perception_fake_half = net_vgg(fake_out_half)
                # ----------------- Perceptual loss ----------------- #
                loss_g_perception = sum(
                    [criterionL1(perception_fake[i], perception_real[i]) +
                     criterionL1(perception_fake_half[i], perception_real_half[i])
                     for i in range(len(perception_real))]
                ) / (len(perception_real) * 2)
                # ----------------- GAN loss ----------------- #
loss_g_dI = criterionGAN(pred_fake_dI, True)
loss_g_dV = criterionGAN(pred_fake_dV, True)
                # ----------------- Lip-sync loss ----------------- #
                fake_out_clip = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1)
                # mouth_region_size defines the mouth crop size (set consistently with train_data)
                mouth_region_size = opt.mouth_region_size
                radius = mouth_region_size // 2
                radius_1_4 = radius // 4
                # Compute the start/end indices of the mouth crop
                start_x, start_y = radius, radius_1_4
                end_x, end_y = start_x + mouth_region_size, start_y + mouth_region_size
                fake_out_clip_mouth_origin_size = fake_out_clip[:, :, start_x:end_x, start_y:end_y]
                # Resize the mouth crop to 256x256 to match the lip-sync network input
                if mouth_region_size != 256:
                    fake_out_clip_mouth = F.interpolate(fake_out_clip_mouth_origin_size, size=(256, 256), mode='bilinear')
                else:
                    fake_out_clip_mouth = fake_out_clip_mouth_origin_size
sync_score = net_lipsync(fake_out_clip_mouth, deep_speech_full)
loss_sync = criterionMSE(sync_score, torch.tensor(1.0).expand_as(sync_score).cuda())
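                # loss_sync drives SyncNet's audio-visual score toward 1 (in sync)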
                # ----------------- Reconstruction (MSE) loss ----------------- #
                loss_img = criterionMSE(fake_out, source_clip)
                loss_g = (loss_img * opt.lambda_img +
                          loss_g_perception * opt.lamb_perception +
                          loss_g_dI * opt.lambda_g_dI +
                          loss_g_dV * opt.lambda_g_dV +
                          loss_sync * opt.lamb_syncnet_perception)
scaler.scale(loss_g).backward()
scaler.step(optimizer_g)
scaler.update()
            # Periodically log losses and sample images to WandB
if iteration % opt.freq_wandb == 0:
log_to_wandb(source_clip, fake_out)
wandb.log({
"epoch": epoch,
"loss_dI": loss_dI.item(),
"loss_dV": loss_dV.item(),
"loss_g": loss_g.item(),
"loss_img": loss_img.item(),
"loss_g_perception": loss_g_perception.item(),
"loss_g_dI": loss_g_dI.item(),
"loss_g_dV": loss_g_dV.item(),
"loss_sync": loss_sync.item()
})
print(
f"Epoch {epoch}, Iteration {iteration}, "
f"loss_dI: {loss_dI.item()}, loss_dV: {loss_dV.item()}, "
f"loss_g: {loss_g.item()}, loss_img: {loss_img.item()}, "
f"loss_g_perception: {loss_g_perception.item()}, "
f"loss_g_dI: {loss_g_dI.item()}, loss_g_dV: {loss_g_dV.item()}, "
f"loss_sync: {loss_sync.item()}"
)
        # Step the learning-rate schedulers (once per epoch)
update_learning_rate(net_g_scheduler, optimizer_g)
update_learning_rate(net_dI_scheduler, optimizer_dI)
update_learning_rate(net_dV_scheduler, optimizer_dV)
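        # Assumption: update_learning_rate (from utils.training_utils) advances
        # the scheduler and applies the new lr to the optimizer; it is called
        # once per epoch, matching the epoch-based get_scheduler setup.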
if epoch % opt.checkpoint == 0:
save_checkpoint(epoch, opt, net_g, net_dI, net_dV, optimizer_g, optimizer_dI, optimizer_dV)
if epoch == 1:
config_dict = vars(opt)
config_out_path = os.path.join(opt.result_path, f'config_{time.strftime("%Y-%m-%d-%H-%M-%S")}.yaml')
save_config_to_yaml(config_dict, config_out_path)
# Save a training checkpoint
def save_checkpoint(epoch, opt, net_g, net_dI, net_dV, optimizer_g, optimizer_dI, optimizer_dV):
model_out_path = os.path.join(opt.result_path, f'netG_model_epoch_{epoch}.pth')
states = {
'epoch': epoch + 1,
'state_dict': {
'net_g': net_g.state_dict(),
'net_dI': net_dI.state_dict(),
'net_dV': net_dV.state_dict()
},
'optimizer': {
'net_g': optimizer_g.state_dict(),
'net_dI': optimizer_dI.state_dict(),
'net_dV': optimizer_dV.state_dict()
}
}
torch.save(states, model_out_path)
print(f"Checkpoint saved to {model_out_path}")
# Main entry point
if __name__ == "__main__":
    # Parse the configuration file path and experiment name
    config_parser = argparse.ArgumentParser(description="Train LawDNet clip model", add_help=False)
    config_parser.add_argument('--config_path', type=str, required=True, help="Path to the experiment configuration file.")
    # Experiment name (also used as the WandB project identifier)
    config_parser.add_argument('--name', type=str, required=True, help="Name of the experiment.")
    args, remaining_argv = config_parser.parse_known_args()
    # After extracting config_path, use it to load the configuration
init_wandb(args.name)
opt, device = load_config_and_device(args)
os.makedirs(opt.result_path, exist_ok=True)
training_data_loader = load_training_data(opt)
    net_g, net_dI, net_dV, net_vgg, net_lipsync = init_networks(opt, device)
    optimizer_g, optimizer_dI, optimizer_dV = setup_optimizers(net_g, net_dI, net_dV, opt)
load_pretrained_weights(net_g, opt)
criterionGAN, criterionL1, criterionMSE, criterionCosine = setup_criterion()
    net_g_scheduler, net_dI_scheduler, net_dV_scheduler = setup_schedulers(optimizer_g, optimizer_dI, optimizer_dV, opt)
    train(opt, net_g, net_dI, net_dV, net_vgg, net_lipsync, training_data_loader, optimizer_g, optimizer_dI, optimizer_dV, criterionGAN, criterionL1, criterionMSE, criterionCosine, net_g_scheduler, net_dI_scheduler, net_dV_scheduler)
wandb.finish()
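# Example launch (paths and names are illustrative, not from the repo):
#   python train_LawDNet_clip.py \
#       --config_path config/experiment_settings.py \
#       --name lawdnet_clip_run1
# The file passed via --config_path must define an `experiment_config` dict,
# as required by load_experiment_config() above.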