-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathloss.py
93 lines (75 loc) · 3.4 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import scipy.ndimage as nd
class OhemCrossEntropy2d(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    Pixels whose predicted probability for their ground-truth class is above
    a threshold are treated as "easy" and relabelled with ``ignore_label``,
    so only the remaining (hard) pixels contribute to the cross-entropy.

    Args:
        ignore_label: label value excluded from the loss.
        thresh: base probability threshold; pixels with a true-class
            probability above the effective threshold are dropped.
        min_kept: target minimum number of pixels to keep (counted at full
            resolution); the threshold is raised above ``thresh`` when
            needed to reach it.
        factor: spatial down-sampling factor used when estimating the
            threshold.
    """

    def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        self.factor = factor
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    def _true_class_probs(self, np_predict, np_target):
        """Return flattened labels, valid pixel indices and true-class probabilities.

        Args:
            np_predict: probability array of shape (n, c, h, w).
            np_target: label array of shape (n, h, w).

        Returns:
            Tuple ``(flat_label, valid_inds, true_prob)`` where ``flat_label``
            is an int32 copy of the labels flattened to 1-D, ``valid_inds``
            are the flat indices whose label is not ``ignore_label`` and
            ``true_prob[i]`` is the probability predicted for the
            ground-truth class of pixel ``valid_inds[i]``.
        """
        num_classes = np_predict.shape[1]
        # astype() copies, so callers may modify flat_label freely.
        flat_label = np_target.ravel().astype(np.int32)
        # Class axis first, then flatten (n, h, w) in the same order as ravel().
        flat_prob = np.moveaxis(np_predict, 1, 0).reshape((num_classes, -1))
        valid_inds = np.where(flat_label != self.ignore_label)[0]
        true_prob = flat_prob[flat_label[valid_inds], valid_inds]
        return flat_label, valid_inds, true_prob

    def find_threshold(self, np_predict, np_target):
        """Estimate the probability threshold below which pixels are kept.

        The estimate is made on inputs spatially down-sampled by
        ``self.factor`` (linear interpolation for probabilities, nearest
        neighbour for labels), with ``min_kept`` scaled down accordingly.

        Args:
            np_predict: probability array of shape (n, c, h, w).
            np_target: label array of shape (n, h, w).

        Returns:
            ``1.0`` when there are no more valid pixels than the scaled
            ``min_kept`` (keep everything); otherwise the larger of
            ``self.thresh`` and the ``min_kept``-th smallest true-class
            probability.
        """
        factor = self.factor
        predict = nd.zoom(np_predict, (1.0, 1.0, 1.0 / factor, 1.0 / factor), order=1)
        target = nd.zoom(np_target, (1.0, 1.0 / factor, 1.0 / factor), order=0)

        # int() keeps the value usable as an index even if factor is a float.
        min_kept = int(self.min_kept // (factor * factor))
        _, valid_inds, true_prob = self._true_class_probs(predict, target)
        num_valid = len(valid_inds)

        if min_kept >= num_valid:
            return 1.0

        threshold = self.thresh
        if min_kept > 0:
            # min_kept < num_valid here, so min_kept - 1 is a valid index.
            k_th = min_kept - 1
            kth_smallest = np.partition(true_prob, k_th)[k_th]
            if kth_smallest > self.thresh:
                threshold = kth_smallest
        return threshold

    def generate_new_target(self, predict, target):
        """Build a target in which easy pixels are replaced by ``ignore_label``.

        Args:
            predict: class-probability tensor of shape (n, c, h, w).
            target: label tensor of shape (n, h, w).

        Returns:
            A new int64 tensor with the shape and device of ``target``;
            ``target`` itself is not modified.
        """
        np_predict = predict.detach().cpu().numpy()
        np_target = target.detach().cpu().numpy()

        threshold = self.find_threshold(np_predict, np_target)
        flat_label, valid_inds, true_prob = self._true_class_probs(np_predict, np_target)

        kept_inds = valid_inds[true_prob <= threshold]
        # Diagnostic output retained from the original implementation.
        print('Labels: {} {}'.format(len(kept_inds), threshold))

        new_label = np.full_like(flat_label, self.ignore_label)
        new_label[kept_inds] = flat_label[kept_inds]

        new_target = torch.from_numpy(new_label.reshape(tuple(target.shape))).long()
        # .to(target.device) works for CPU and CUDA targets alike.
        return new_target.to(target.device)

    def forward(self, predict, target, weight=None):
        """
            Args:
                predict:(n, c, h, w)
                target:(n, h, w)
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
        """
        assert not target.requires_grad
        input_prob = F.softmax(predict, 1)
        target = self.generate_new_target(input_prob, target)
        if weight is None:
            return self.criterion(predict, target)
        # Honour the documented per-class weights.
        return F.cross_entropy(predict, target, weight=weight,
                               ignore_index=self.ignore_label)