"""
CrossentropyND and TopKLoss are from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/ND_Crossentropy.py
"""
import torch
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt
import numpy as np
class CrossentropyND(torch.nn.CrossEntropyLoss):
"""
Network has to have NO NONLINEARITY!
"""
def forward(self, inp, target):
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
        while i1 < len(inp.shape):  # ugly, but torch only allows transposing two axes at once
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
target = target.view(-1,)
return super(CrossentropyND, self).forward(inp, target)
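# Usage sketch (illustrative helper, not from the original repo): CrossentropyND expects
# raw logits of shape (b, c, x, y, z) and an integer label map of shape (b, x, y, z).
def _example_crossentropy_nd():
    ce = CrossentropyND()
    logits = torch.randn(2, 4, 8, 8, 8)          # unnormalised network output, 4 classes
    labels = torch.randint(0, 4, (2, 8, 8, 8))   # voxel-wise class indices
    return ce(logits, labels)                    # scalar mean cross-entropy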
class TopKLoss(CrossentropyND):
"""
    Network has to have NO NONLINEARITY!
"""
def __init__(self, weight=None, ignore_index=-100, k=10):
self.k = k
        # reduction='none' (the modern equivalent of size_average=False, reduce=False)
        # keeps the per-voxel losses so the top-k selection in forward() can be applied
        super(TopKLoss, self).__init__(weight=weight, ignore_index=ignore_index, reduction='none')
def forward(self, inp, target):
target = target[:, 0].long()
res = super(TopKLoss, self).forward(inp, target)
num_voxels = np.prod(res.shape)
res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
return res.mean()
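# Usage sketch (illustrative helper): TopKLoss averages only the k% hardest voxels; note
# that forward() indexes target[:, 0], so the label map needs an explicit channel axis.
def _example_topk_loss():
    topk = TopKLoss(k=10)
    logits = torch.randn(2, 3, 16, 16)           # raw logits, 3 classes, 2D case
    labels = torch.randint(0, 3, (2, 1, 16, 16)) # (b, 1, x, y)
    return topk(logits, labels)                  # mean over the 10% largest voxel losses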
class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
"""
Network has to have NO NONLINEARITY!
"""
def __init__(self, weight=None):
super(WeightedCrossEntropyLoss, self).__init__()
self.weight = weight
def forward(self, inp, target):
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
        while i1 < len(inp.shape):  # ugly, but torch only allows transposing two axes at once
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
target = target.view(-1,)
wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
return wce_loss(inp, target)
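# Usage sketch (illustrative helper): WeightedCrossEntropyLoss takes a per-class weight
# tensor at construction time and otherwise behaves like CrossentropyND.
def _example_weighted_ce():
    wce = WeightedCrossEntropyLoss(weight=torch.tensor([0.2, 0.8]))
    logits = torch.randn(2, 2, 8, 8, 8)          # raw logits for a binary problem
    labels = torch.randint(0, 2, (2, 8, 8, 8))
    return wce(logits, labels)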
class WeightedCrossEntropyLossV2(torch.nn.Module):
"""
WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
    Network has to have NO NONLINEARITY!
copy from: https://github.com/wolny/pytorch-3dunet/blob/6e5a24b6438f8c631289c10638a17dea14d42051/unet3d/losses.py#L121
"""
def forward(self, net_output, gt):
# compute weight
# shp_x = net_output.shape
# shp_y = gt.shape
# print(shp_x, shp_y)
# with torch.no_grad():
# if len(shp_x) != len(shp_y):
# gt = gt.view((shp_y[0], 1, *shp_y[1:]))
# if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# # if this is the case then gt is probably already a one hot encoding
# y_onehot = gt
# else:
# gt = gt.long()
# y_onehot = torch.zeros(shp_x)
# if net_output.device.type == "cuda":
# y_onehot = y_onehot.cuda(net_output.device.index)
# y_onehot.scatter_(1, gt, 1)
# y_onehot = y_onehot.transpose(0,1).contiguous()
# class_weights = (torch.einsum("cbxyz->c", y_onehot).type(torch.float32) + 1e-10)/torch.numel(y_onehot)
# print('class_weights', class_weights)
# class_weights = class_weights.view(-1)
        # hard-coded class weights for a binary (background/foreground) problem;
        # built on the prediction's device so the loss also runs on CPU
        class_weights = torch.tensor([0.2, 0.8], dtype=torch.float32, device=net_output.device)
gt = gt.long()
num_classes = net_output.size()[1]
# class_weights = self._class_weights(inp)
i0 = 1
i1 = 2
        while i1 < len(net_output.shape):  # ugly, but torch only allows transposing two axes at once
net_output = net_output.transpose(i0, i1)
i0 += 1
i1 += 1
net_output = net_output.contiguous()
net_output = net_output.view(-1, num_classes) #shape=(vox_num, class_num)
gt = gt.view(-1,)
# print('*'*20)
        # NOTE: class_weights is built above but not passed to F.cross_entropy, so this
        # currently reduces to plain cross-entropy; add weight=class_weights to apply it
        return F.cross_entropy(net_output, gt)  # , weight=class_weights
# @staticmethod
# def _class_weights(input):
# # normalize the input first
# input = F.softmax(input, _stacklevel=5)
# flattened = flatten(input)
# nominator = (1. - flattened).sum(-1)
# denominator = flattened.sum(-1)
# class_weights = Variable(nominator / denominator, requires_grad=False)
# return class_weights
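# Usage sketch (illustrative helper): WeightedCrossEntropyLossV2 is called the same way,
# with raw logits (b, c, x, y, z) and an integer label map (b, x, y, z).
def _example_weighted_ce_v2():
    logits = torch.randn(2, 2, 8, 8, 8)
    labels = torch.randint(0, 2, (2, 8, 8, 8))
    return WeightedCrossEntropyLossV2()(logits, labels)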
def flatten(tensor):
"""Flattens a given tensor such that the channel axis is first.
The shapes are transformed as follows:
(N, C, D, H, W) -> (C, N * D * H * W)
"""
C = tensor.size(1)
# new axis order
axis_order = (1, 0) + tuple(range(2, tensor.dim()))
# Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
transposed = tensor.permute(axis_order)
# Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
transposed = transposed.contiguous()
return transposed.view(C, -1)
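# Shape sketch (illustrative helper): flatten() moves the channel axis to the front and
# merges everything else, e.g. (N, C, D, H, W) = (2, 3, 4, 5, 6) -> (3, 2 * 4 * 5 * 6).
def _example_flatten():
    x = torch.randn(2, 3, 4, 5, 6)
    return flatten(x).shape                      # torch.Size([3, 240])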
def compute_edts_forPenalizedLoss(GT):
    """
    GT.shape = (batch_size, x, y, z)
    Only for binary segmentation; every sample should contain both foreground and
    background voxels, otherwise the per-region normalisation below divides by zero.
    """
    # squeeze away singleton axes (e.g. a channel axis of size 1); note that this
    # also drops the batch axis when batch_size == 1
    GT = np.squeeze(GT)
    res = np.zeros(GT.shape)
    for i in range(GT.shape[0]):
        posmask = GT[i]
        negmask = ~posmask
        # Euclidean distance to the region boundary, flipped so that voxels close to
        # the boundary receive the largest weight and the region interior goes to zero
        pos_edt = distance_transform_edt(posmask)
        pos_edt = (np.max(pos_edt) - pos_edt) * posmask
        neg_edt = distance_transform_edt(negmask)
        neg_edt = (np.max(neg_edt) - neg_edt) * negmask
        # normalise each region to [0, 1] and combine foreground and background maps
        res[i] = pos_edt / np.max(pos_edt) + neg_edt / np.max(neg_edt)
    return res
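# Usage sketch (illustrative helper): the returned map is largest near the foreground /
# background boundary, so multiplying the per-voxel loss by it penalises boundary errors
# the most; each sample here contains both classes to keep the normalisation well defined.
def _example_distance_weights():
    gt = np.zeros((2, 8, 8, 8), dtype=bool)
    gt[:, :, :4, :] = True                       # half foreground, half background
    return compute_edts_forPenalizedLoss(gt)     # shape (2, 8, 8, 8), values in [0, 1]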
class DisPenalizedCE(torch.nn.Module):
"""
Only for binary 3D segmentation
Network has to have NO NONLINEARITY!
"""
def forward(self, inp, target):
        # print(inp.shape, target.shape)  # inp: (batch, 2, x, y, z), target: (batch, x, y, z)
        # compute distance map of ground truth
with torch.no_grad():
dist = compute_edts_forPenalizedLoss(target.cpu().numpy()>0.5) + 1.0
        # move the distance map onto the prediction's device and dtype so the
        # element-wise weighting below works regardless of CPU/GPU placement
        dist = torch.from_numpy(dist).to(device=inp.device, dtype=torch.float32)
dist = dist.view(-1,)
target = target.long()
num_classes = inp.size()[1]
i0 = 1
i1 = 2
        while i1 < len(inp.shape):  # ugly, but torch only allows transposing two axes at once
inp = inp.transpose(i0, i1)
i0 += 1
i1 += 1
inp = inp.contiguous()
inp = inp.view(-1, num_classes)
log_sm = torch.nn.LogSoftmax(dim=1)
inp_logs = log_sm(inp)
target = target.view(-1,)
# loss = nll_loss(inp_logs, target)
loss = -inp_logs[range(target.shape[0]), target]
# print(loss.type(), dist.type())
        weighted_loss = loss * dist
        # return the distance-weighted loss; otherwise the distance map would have no effect
        return weighted_loss.mean()
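# Usage sketch (illustrative helper): DisPenalizedCE takes binary logits (b, 2, x, y, z)
# and a binary target (b, x, y, z); the distance map is built on the CPU from the target,
# and every sample must contain both classes.
def _example_dis_penalized_ce():
    logits = torch.randn(2, 2, 8, 8, 8)
    target = torch.zeros(2, 8, 8, 8)
    target[:, :, :4, :] = 1                      # half foreground, half background
    return DisPenalizedCE()(logits, target)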
def nll_loss(input, target):
"""
customized nll loss
source: https://medium.com/@zhang_yang/understanding-cross-entropy-
implementation-in-pytorch-softmax-log-softmax-nll-cross-entropy-416a2b200e34
"""
loss = -input[range(target.shape[0]), target]
return loss.mean()
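# Sanity sketch (illustrative helper): the customized nll_loss matches F.nll_loss when
# fed log-probabilities from log_softmax.
def _example_nll_loss():
    logits = torch.randn(5, 3)
    logp = F.log_softmax(logits, dim=1)
    tgt = torch.randint(0, 3, (5,))
    return nll_loss(logp, tgt), F.nll_loss(logp, tgt)  # the two values agree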