# blend_loss.py
# forked from ZPdesu/Barbershop
import os

import PIL
import torch

from losses import masked_lpips


class BlendLossBuilder(torch.nn.Module):
    """Blending loss: masked LPIPS terms on the face region and the hair region."""

    def __init__(self, opt):
        super(BlendLossBuilder, self).__init__()
        self.opt = opt
        # (weight, loss name) pairs; both regions are weighted equally.
        self.parsed_loss = [[1.0, 'face'], [1.0, 'hair']]
        use_gpu = opt.device == 'cuda'

        # Masked perceptual (LPIPS) losses over the first three VGG blocks,
        # one instance for the face region and one for the hair region.
        self.face_percept = masked_lpips.PerceptualLoss(
            model="net-lin", net="vgg", vgg_blocks=['1', '2', '3'], use_gpu=use_gpu
        )
        self.face_percept.eval()

        self.hair_percept = masked_lpips.PerceptualLoss(
            model="net-lin", net="vgg", vgg_blocks=['1', '2', '3'], use_gpu=use_gpu
        )
        self.hair_percept.eval()

    def _loss_face_percept(self, gen_im, ref_im, mask, **kwargs):
        # LPIPS between generated and face-reference image, restricted to the face mask.
        return self.face_percept(gen_im, ref_im, mask=mask)

    def _loss_hair_percept(self, gen_im, ref_im, mask, **kwargs):
        # LPIPS between generated and hair-reference image, restricted to the hair mask.
        return self.hair_percept(gen_im, ref_im, mask=mask)

    def forward(self, gen_im, im_1, im_3, mask_face, mask_hair):
        """Return the total blending loss and a dict of the per-region terms.

        gen_im: generated (blended) image.
        im_1:   reference image for the face region.
        im_3:   reference image for the hair region.
        mask_face, mask_hair: masks selecting the corresponding regions.
        """
        loss = 0
        loss_fun_dict = {
            'face': self._loss_face_percept,
            'hair': self._loss_hair_percept,
        }
        losses = {}
        for weight, loss_type in self.parsed_loss:
            if loss_type == 'face':
                var_dict = {
                    'gen_im': gen_im,
                    'ref_im': im_1,
                    'mask': mask_face,
                }
            elif loss_type == 'hair':
                var_dict = {
                    'gen_im': gen_im,
                    'ref_im': im_3,
                    'mask': mask_hair,
                }
            tmp_loss = loss_fun_dict[loss_type](**var_dict)
            losses[loss_type] = tmp_loss
            loss += weight * tmp_loss
        return loss, losses
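

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original Barbershop file): wires
# BlendLossBuilder to dummy tensors so the forward pass can be exercised. The
# 256x256 resolution, the mask shapes, and the SimpleNamespace-based `opt`
# are assumptions for illustration only; the real pipeline passes its own
# options object and properly aligned images and segmentation masks.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    opt = SimpleNamespace(device='cuda' if torch.cuda.is_available() else 'cpu')
    loss_builder = BlendLossBuilder(opt)

    # Dummy generated image, face reference (im_1), hair reference (im_3),
    # and all-ones region masks, for shape checking only.
    gen_im = torch.rand(1, 3, 256, 256, device=opt.device)
    im_1 = torch.rand(1, 3, 256, 256, device=opt.device)
    im_3 = torch.rand(1, 3, 256, 256, device=opt.device)
    mask_face = torch.ones(1, 1, 256, 256, device=opt.device)
    mask_hair = torch.ones(1, 1, 256, 256, device=opt.device)

    total, per_term = loss_builder(gen_im, im_1, im_3, mask_face, mask_hair)
    print(total, per_term)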