import math
import numbers
import torch
from torch import nn
from torch.nn import functional as F


class GaussianSmoothing(nn.Module):
    """
    Apply Gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.

    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight of shape (channels, 1, *kernel_size).
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels

        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Note that the convolution uses no padding, so each spatial dimension
        of the output is smaller than the input by kernel_size - 1.

        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
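

# Illustrative usage sketch (not part of the original file; the helper name
# `_gaussian_smoothing_example` is hypothetical): smooth a 3-channel image with
# the depthwise Gaussian above while preserving its spatial size.
def _gaussian_smoothing_example():
    smoothing = GaussianSmoothing(channels=3, kernel_size=5, sigma=1.0, dim=2)
    img = torch.rand(1, 3, 64, 64)
    # The module's convolution uses no padding, so reflect-pad by
    # kernel_size // 2 per side to keep the 64x64 resolution.
    smoothed = smoothing(F.pad(img, (2, 2, 2, 2), mode='reflect'))
    return smoothed  # shape (1, 3, 64, 64)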


def image_loss(model_out, gt, mask=None):
    # Mean absolute (L1) error between predicted and ground-truth RGB;
    # NaN pixels are zeroed first. `mask` is accepted but unused here.
    gt_rgb = gt['rgb']
    gt_rgb[torch.isnan(gt_rgb)] = 0.0

    rgb = model_out['rgb']
    rgb[torch.isnan(rgb)] = 0.0

    loss = torch.abs(gt_rgb - rgb).mean()
    return loss
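

# Illustrative sketch (not part of the original file; `_image_loss_example` is
# a hypothetical helper): image_loss reduces to the mean absolute error over
# all rays and channels.
def _image_loss_example():
    gt = {'rgb': torch.rand(2, 1024, 3)}
    model_out = {'rgb': torch.rand(2, 1024, 3)}
    return image_loss(model_out, gt)  # scalar tensor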


class LFLoss():
    def __init__(self, l2_weight=1e-3, lpips=False, depth=False, reg_weight=1e2):
        self.l2_weight = l2_weight
        self.reg_weight = reg_weight
        self.lpips = lpips
        self.depth = depth

        # Note: the `lpips` package import below shadows the `lpips` flag
        # argument for the rest of __init__; the flag is already stored in
        # self.lpips. Despite the variable name, the VGG backbone is used.
        import lpips
        loss_fn_alex = lpips.LPIPS(net='vgg').cuda()
        self.loss_fn_alex = loss_fn_alex.cuda()
        self.smooth = GaussianSmoothing(1, 15, 7).cuda()
        self.avg = nn.AdaptiveAvgPool2d((2, 2))
        self.upsample = nn.Upsample((32, 32), mode='bilinear')

    def __call__(self, model_out, gt, model=None, val=False):
        loss_dict = {}
        loss_dict['img_loss'] = image_loss(model_out, gt)

        if self.lpips:
            gt_rgb = gt['rgb']
            mask = gt['mask']
            pred_rgb = model_out['rgb']
            valid_mask = model_out['valid_mask']
            offset = 32

            # Reshape flat ray batches into (N, 3, 32, 32) patches for LPIPS.
            gt_rgb = gt_rgb.reshape((-1, offset, offset, 3)).permute(0, 3, 1, 2)
            pred_rgb = pred_rgb.reshape((-1, offset, offset, 3)).permute(0, 3, 1, 2)

            if mask.size(0) == gt_rgb.size(0):
                gt_rgb = gt_rgb * mask[:, None, None, None]
                pred_rgb = pred_rgb * mask[:, None, None, None]

            lpips_loss = self.loss_fn_alex(gt_rgb, pred_rgb)
            # 0.2 for realestate
            loss_dict['lpips_loss'] = 0.1 * lpips_loss

        if self.depth and (not val):
            # Penalize squared deviation of each 32x32 depth patch from its
            # per-patch mean, weighted by l2_weight.
            depth_ray = model_out['depth_ray'][..., 0]
            depth_ray = depth_ray.reshape((-1, 1, 32, 32))
            depth_mean = depth_ray.mean(dim=-1).mean(dim=-1)[:, None, None]
            depth_dist = self.l2_weight * torch.pow(depth_ray - depth_mean, 2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            mask = gt['mask']
            depth_loss = depth_dist * mask[:, None]
            loss_dict['depth_loss'] = depth_loss.mean()

        return loss_dict, {}
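

# Illustrative usage sketch (not part of the original file; `_lf_loss_example`
# is a hypothetical helper). Assumes `model_out` and `gt` carry the keys and
# 32x32-patch layout that LFLoss.__call__ expects above, and that a CUDA
# device is available, since LFLoss moves the LPIPS network onto the GPU.
def _lf_loss_example(model_out, gt):
    criterion = LFLoss(l2_weight=1e-3, lpips=True, depth=True)
    loss_dict, _ = criterion(model_out, gt, val=False)
    # Combine the individual terms into a single scalar for backprop.
    total = sum(loss.mean() for loss in loss_dict.values())
    return total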