-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathLoss.py
120 lines (95 loc) · 4.16 KB
/
Loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn.functional as F
class NCC(torch.nn.Module):
    """
    Local (windowed) normalized cross-correlation loss, accelerated with 3-D
    cumulative sums (summed-area tables) instead of convolutions.

    For every voxel, the squared correlation coefficient between ``I`` and
    ``J`` over a ``win**3`` box is computed, and the loss is
    ``1 - mean(cc**2)``. Inputs must be 5-D tensors whose axes 2, 3 and 4 are
    the spatial axes (the window sums run over those axes). Voxels outside
    the volume are treated as zeros, and every window is normalised by the
    full ``win**3`` volume, including at the borders.

    Args:
        win: Side length of the cubic window. Odd values give a box centred
            on each voxel. Even values are also accepted; the box then reaches
            one voxel further forward than backward along each axis.
        eps: Added to the variance product in the denominator so that flat
            (zero-variance) regions do not divide by zero.
    """

    def __init__(self, win=21, eps=1e-5):
        super(NCC, self).__init__()
        self.eps = eps
        self.win = win
        self.win_raw = win  # not read inside this class; kept for external callers

    def window_sum_cs3D(self, I, win_size):
        """Return the sum of ``I`` over a ``win_size**3`` box around every voxel.

        The output has the same shape as ``I``. The sums come from a 3-D
        cumulative sum of the zero-padded input followed by 3-D
        inclusion-exclusion, so the cost does not depend on ``win_size``.

        Raises:
            ValueError: if ``win_size`` is smaller than 1.
        """
        if win_size < 1:
            raise ValueError(f"win_size must be a positive integer, got {win_size!r}")

        # Each padded axis must be exactly ``win_size`` longer than the input,
        # so that cs[i + win] - cs[i] spans ``win_size`` elements and both
        # slices below have the input's length. The extra leading zero lets
        # cs[0] act as the empty prefix. For odd windows this is
        # (win//2 + 1, win//2); the previous (win//2 + 1, win//2) layout also
        # for even windows made the slices differ in length by one.
        before = win_size - win_size // 2
        after = win_size // 2
        I_padded = F.pad(I, pad=[before, after] * 3, mode='constant', value=0)

        # Summed-area table over the three spatial axes.
        I_cs = torch.cumsum(I_padded, dim=2)
        I_cs = torch.cumsum(I_cs, dim=3)
        I_cs = torch.cumsum(I_cs, dim=4)

        x, y, z = I.shape[2:]
        w = win_size
        # 3-D inclusion-exclusion: corner "upper" minus the three faces, plus
        # the three edges, minus the "lower" corner.
        return (I_cs[:, :, w:, w:, w:]
                - I_cs[:, :, w:, w:, :z]
                - I_cs[:, :, w:, :y, w:]
                - I_cs[:, :, :x, w:, w:]
                + I_cs[:, :, w:, :y, :z]
                + I_cs[:, :, :x, w:, :z]
                + I_cs[:, :, :x, :y, w:]
                - I_cs[:, :, :x, :y, :z])

    def forward(self, I, J):
        """Return ``1 - mean(local_cc**2)`` as a float32 scalar tensor.

        Both inputs are promoted to float64 before the window sums
        (presumably to limit round-off in the cumulative sums). ``cc**2`` is
        ``cross**2 / (var_I * var_J + eps)``, so identical inputs give a loss
        close to 0.
        """
        I = I.double()
        J = J.double()
        win = self.win

        # Local sums of the first and second moments.
        I_sum = self.window_sum_cs3D(I, win)
        J_sum = self.window_sum_cs3D(J, win)
        I2_sum = self.window_sum_cs3D(I * I, win)
        J2_sum = self.window_sum_cs3D(J * J, win)
        IJ_sum = self.window_sum_cs3D(I * J, win)

        win_volume = (win * 1.) ** 3
        u_I = I_sum / win_volume
        u_J = J_sum / win_volume

        # Expanded forms of sum((I - u_I)(J - u_J)), sum((I - u_I)^2) and sum((J - u_J)^2).
        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_volume
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_volume
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_volume

        cc = cross * cross / (I_var * J_var + self.eps)  # squared local correlation
        return 1. - torch.mean(cc).float()
def JacboianDet(J):
    """Finite-difference Jacobian determinant of a dense 3-component field.

    ``J`` is indexed as ``J[batch, a1, a2, a3, component]``, with three
    components on the last axis. Component ``c`` is first multiplied by the
    size of spatial axis ``c + 1``. Forward differences are then taken along
    the three spatial axes, dropping the last slice of every axis. The 3x3
    determinant whose rows are the differences along (axis 2, axis 1,
    axis 3) is expanded along its first row.

    NOTE(review): the row order (axis 2, axis 1, axis 3) together with the
    component order fixes the sign. A field whose component 0 increases along
    the last spatial axis (the ``F.grid_sample`` grid layout) gives a
    positive determinant. A field whose component ``c`` increases along axis
    ``c + 1`` gives a negative one. The scale factor, however, pairs
    component ``c`` with axis ``c + 1``, and the two only agree for cubic
    volumes. No identity is added to the differences, so ``J`` is presumably
    the full sampling grid rather than a pure displacement. Confirm both
    conventions against the callers.

    Returns:
        Tensor of shape ``(B, a1 - 1, a2 - 1, a3 - 1)`` for 5-D input.
    """
    sizes = [J.size(1), J.size(2), J.size(3)]
    scale = torch.tensor(sizes).to(J).view(1, 1, 1, 1, 3) * 1.
    field = J * scale

    origin = field[:, :-1, :-1, :-1, :]
    step_ax1 = field[:, 1:, :-1, :-1, :] - origin
    step_ax2 = field[:, :-1, 1:, :-1, :] - origin
    step_ax3 = field[:, :-1, :-1, 1:, :] - origin

    # Matrix rows: r = step along axis 2, s = along axis 1, t = along axis 3.
    r0, r1, r2 = (step_ax2[:, :, :, :, k] for k in range(3))
    s0, s1, s2 = (step_ax1[:, :, :, :, k] for k in range(3))
    t0, t1, t2 = (step_ax3[:, :, :, :, k] for k in range(3))

    term0 = r0 * (s1 * t2 - s2 * t1)
    term1 = r1 * (s0 * t2 - s2 * t0)
    term2 = r2 * (s0 * t1 - s1 * t0)
    return term0 - term1 + term2
def neg_Jdet(J):
    """Return the fraction of voxels whose Jacobian determinant is negative.

    The determinant map comes from ``JacboianDet(J)``. The result is the
    number of strictly negative entries divided by the total entry count, as
    a floating-point scalar tensor.
    """
    folded = JacboianDet(J) < 0
    return torch.count_nonzero(folded) / folded.numel()
def neg_Jdet_loss(J):
    """Mean hinge penalty on negative Jacobian determinants.

    Each voxel contributes ``max(0, -det)``, where ``det`` comes from
    ``JacboianDet(J)``. Voxels with a non-negative determinant contribute 0.
    The penalties are averaged over all entries.
    """
    penalty = F.relu(-JacboianDet(J))
    return penalty.mean()
def smoothloss_loss(df):
    """First-order smoothness penalty on a dense field.

    For each of axes 2, 3 and 4 of ``df``, the forward differences along
    that axis are squared and averaged. The three per-axis means are summed.
    Presumably ``df`` is laid out as ``(B, C, X, Y, Z)``; TODO confirm with
    callers.
    """
    return sum((torch.diff(df, dim=axis) ** 2).mean() for axis in (2, 3, 4))
def magnitude_loss(all_v):
    """Mean squared magnitude of 3-component vectors stored along axis 2.

    ``all_v`` must be 6-D. Entries 0, 1 and 2 of axis 2 are taken as the
    three vector components. Their squares are summed element-wise and
    averaged over every remaining index.
    """
    vx, vy, vz = (all_v[:, :, axis, :, :, :] for axis in range(3))
    return torch.mean(vx * vx + vy * vy + vz * vz)