-
Notifications
You must be signed in to change notification settings - Fork 2
/
avatar_reenact.py
93 lines (77 loc) · 3.03 KB
/
avatar_reenact.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import yaml
import torch
import argparse
from tools.util import seed_everything
from tools.util import EasyDict
from train.callbacks import ModelCallbacks
from train.loader import Reenactor
from common import load_config
from common import construct_datasets
from common import load_identity_info
from common import construct_model
def main():
    """Run avatar reenactment against a destination dataset.

    Steps:
      1. Parse CLI arguments and seed all RNGs via ``seed_everything``.
      2. Load the YAML config once, with CLI values (name, workspace,
         bg_color, dst_path) applied as overrides.
      3. Build the destination dataset (its ``train`` split) and an ordered,
         batch-size-1 DataLoader over it.
      4. Load identity info, construct the model, and wrap it in a
         ``Reenactor`` that is asked to use the ``'latest'`` checkpoint.
      5. Call ``Reenactor.reenacting`` with
         ``delta_exp = identity canonical_expression - dst mean_expression``.
    """
    parser = argparse.ArgumentParser(
        description="Run avatar reenactment against a destination dataset."
    )
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--model_name', choices=ModelCallbacks.keys(), required=True)
    parser.add_argument('--dst_path', type=str, required=True)
    parser.add_argument('--workspace', type=str, required=True)
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument(
        '--device',
        type=torch.device,
        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    )
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--bg_color', type=str, default='white')
    # Previously hard-coded to 4; the default keeps that behaviour.
    parser.add_argument('--num_workers', type=int, default=4)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    # ---------------------------- load + override config -------------------- #
    # Single config load; CLI values win over the YAML file. dst_path is
    # forwarded so anything reading cfg.dst_path sees the required CLI value.
    cfg = load_config(
        opt.config,
        overrides={
            "name": opt.name,
            "workspace": opt.workspace,
            "bg_color": opt.bg_color,
            "dst_path": opt.dst_path,
        },
    )

    # --------------------------------- dataset ------------------------------ #
    # The destination frames come from the dataset's train split.
    dataset, _ = construct_datasets(opt, cfg, mode="train")
    dst_dataset = dataset.train
    dst_loader = torch.utils.data.DataLoader(
        dst_dataset,
        batch_size=1,
        shuffle=False,  # preserve destination frame order
        collate_fn=dst_dataset.collate_fn,
        num_workers=opt.num_workers,
    )

    # ---------------------------- load identity info ------------------------ #
    identity_dict = load_identity_info(opt, cfg)

    # ---------------------------------- model ------------------------------- #
    # NOTE(review): the positional 0.0 is interpreted by common.construct_model;
    # its meaning is not visible here - TODO confirm and pass it by keyword.
    model = construct_model(
        opt,
        cfg.model,
        0.0,
        identity_dict=identity_dict,
    )

    # ----------------------------- setup reenactor -------------------------- #
    reenactor = Reenactor(
        opt.name,
        cfg,
        model,
        opt.device,
        workspace=opt.workspace,
        use_checkpoint='latest',
    )

    # --------------------------------- execute ------------------------------ #
    # Offset applied by the reenactor: identity's canonical expression minus
    # the destination sequence's mean expression.
    dst_exp = dst_dataset.mean_expression
    src_exp = identity_dict['canonical_expression']
    delta_exp = src_exp - dst_exp

    # normpath strips a trailing separator so basename never returns ''.
    dst_name = os.path.basename(os.path.normpath(opt.dst_path))
    reenactor.reenacting(
        dst_name=dst_name,
        dst_loader=dst_loader,
        delta_exp=delta_exp,
    )


if __name__ == "__main__":
    main()