-
Notifications
You must be signed in to change notification settings - Fork 0
/
msssim_l1.py
93 lines (77 loc) · 3.32 KB
/
msssim_l1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn as nn
import torch.nn.functional as F
class MS_SSIM_L1_LOSS(nn.Module):
    """Mixed MS-SSIM + Gaussian-weighted L1 loss for single-channel images.

    One Gaussian kernel is built per entry of ``gaussian_sigmas``. For every
    kernel the SSIM luminance term ``l`` and contrast-structure term ``cs``
    are evaluated per pixel; both are multiplied across all kernels to form
    the multi-scale term. The final loss is::

        compensation * (alpha * (1 - prod(l) * prod(cs))
                        + (1 - alpha) * gaussian_l1 / data_range)

    where ``gaussian_l1`` is the absolute error blurred with every kernel and
    averaged over the kernels.

    Created on Thu Dec 3 00:28:15 2020
    Modified on Wed Feb 28 20:19:10 2024
    @orig-author for 3C: Yunpeng Li, Tianjin University
    @modified-author for 1C: William Jongwon Han, Carnegie Mellon University

    NOTE (original author): running on CUDA is strongly recommended; the
    large convolution kernels make the CPU path slow.

    Args:
        gaussian_sigmas: standard deviations of the Gaussian kernels; all
            must be positive. The kernel size is derived from the largest.
        data_range: dynamic range of the input values (e.g. 1.0 or 255).
        K: pair ``(K1, K2)`` of SSIM stabilisation constants.
        alpha: weight of the MS-SSIM term; the L1 term gets ``1 - alpha``.
        compensation: constant factor applied to the mixed loss.
        device: optional device for the kernel bank. The kernels are a
            registered (non-persistent) buffer, so ``module.to(...)`` moves
            them as well and ``state_dict()`` is unaffected.

    Raises:
        ValueError: if ``gaussian_sigmas`` is empty or contains a
            non-positive value.
    """

    def __init__(self, gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range=1.0,
                 K=(0.01, 0.03),
                 alpha=0.84,
                 compensation=200.0,
                 device=None):
        super().__init__()

        sigmas = [float(s) for s in gaussian_sigmas]
        if not sigmas:
            raise ValueError("gaussian_sigmas must contain at least one value")
        if any(s <= 0 for s in sigmas):
            # A zero sigma would divide by zero and yield a NaN kernel.
            raise ValueError("gaussian_sigmas must all be positive")

        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.alpha = alpha
        self.compensation = compensation

        # Size the window from the widest Gaussian. Deriving the size from
        # the padding guarantees an odd, centred kernel and an output with
        # the same H x W as the input (33 x 33 / pad 16 for the defaults).
        self.pad = int(2 * max(sigmas))
        filter_size = 2 * self.pad + 1

        # Kernel bank of shape [len(sigmas), 1, k, k]: one output channel per
        # sigma, one input channel (single-channel images).
        g_masks = torch.stack(
            [self._fspecial_gauss_2d(filter_size, s) for s in sigmas]
        ).unsqueeze(1).to(torch.get_default_dtype())
        if device is not None:
            g_masks = g_masks.to(device)
        # Non-persistent buffer: follows .to()/.cuda()/.double() but is not
        # written to (or expected in) checkpoints.
        self.register_buffer("g_masks", g_masks, persistent=False)

    def _fspecial_gauss_1d(self, size, sigma):
        """Create 1-D gauss kernel

        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution
        Returns:
            torch.Tensor: 1D kernel (size), normalised to sum to 1
        """
        coords = torch.arange(size).to(dtype=torch.float)
        coords -= size // 2  # centre the coordinates on zero
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create 2-D gauss kernel

        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution
        Returns:
            torch.Tensor: 2D kernel (size x size)
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        # Separable Gaussian: the 2-D kernel is the outer product of the 1-D one.
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y, mask=None):
        """Return the scalar mixed loss between ``x`` and ``y``.

        Args:
            x: tensor of shape [B, 1, H, W].
            y: tensor with the same shape as ``x``.
            mask: optional weight multiplied into the per-pixel loss map of
                shape [B, H, W] before averaging. A [B, 1, H, W] tensor is
                accepted and its channel axis is dropped.

        Returns:
            torch.Tensor: 0-dim tensor, the mean of the (masked) loss map.
            The mean is taken over all pixels, including masked-out ones.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected x of shape [B, 1, H, W], got {tuple(x.shape)}"
            )

        # Align the kernel bank with the input so the loss also works when
        # the inputs live on another device or use another float dtype.
        kernels = self.g_masks.to(device=x.device)
        if x.is_floating_point():
            kernels = kernels.to(dtype=x.dtype)
        pad = self.pad

        def blur(t):
            # Gaussian-filter ``t`` with every kernel -> [B, n_sigmas, H, W].
            return F.conv2d(t, kernels, groups=1, padding=pad)

        mux = blur(x)
        muy = blur(y)
        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = blur(x * x) - mux2
        sigmay2 = blur(y * y) - muy2
        sigmaxy = blur(x * y) - muxy

        # l(j), cs(j) in MS-SSIM, one channel per sigma
        lum = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        loss_ms_ssim = 1 - lum.prod(dim=1) * cs.prod(dim=1)  # [B, H, W]

        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, 1, H, W]
        # Blur the absolute error with every kernel, then average the scales.
        gaussian_l1 = blur(loss_l1).mean(1)  # [B, H, W]

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix

        if mask is not None:
            if torch.is_tensor(mask) and mask.dim() == 4 and mask.shape[1] == 1:
                # Without this, [B, 1, H, W] * [B, H, W] would broadcast to
                # [B, B, H, W] and mix samples across the batch.
                mask = mask.squeeze(1)
            loss_mix = loss_mix * mask

        return loss_mix.mean()