-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
248 lines (192 loc) · 8.37 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class OriTripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        batch_size (int): not used; kept so existing call sites keep working.
        margin (float): margin for triplet.
    """

    def __init__(self, batch_size, margin=0.3):
        super(OriTripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size,)

        Returns:
            (loss, correct): the margin ranking loss between each anchor's
            hardest negative and hardest positive distance, and the number of
            anchors (a Python int) whose hardest negative is at least as far
            as their hardest positive.
        """
        n = inputs.size(0)
        # Pairwise Euclidean distance: ||a||^2 + ||b||^2 - 2 a.b
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        # Keyword form; the positional addmm_(beta, alpha, mat1, mat2) overload is deprecated.
        dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][~mask[i]].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        # compute accuracy
        correct = torch.ge(dist_an, dist_ap).sum().item()
        return loss, correct
# Adaptive weights
def softmax_weights(dist, mask):
    """Row-wise softmax of ``dist`` restricted to the entries selected by ``mask``.

    Args:
        dist: 2-D tensor of scores; the softmax is taken along dim 1.
        mask: tensor of the same shape, non-zero where an entry participates
            (the callers in this file pass 0/1 float masks).

    Returns:
        Tensor shaped like ``dist``. Masked-out entries get weight 0, each row
        with at least one selected entry sums to ~1 (up to the 1e-6 guard),
        and rows with no selected entry are all zeros.
    """
    valid = mask > 0
    # Shift by the max over the *selected* entries only. Taking the max of
    # ``dist * mask`` let the masked-out zeros win for negated distances, so
    # the shift stayed at 0, exp() underflowed for far-apart features and the
    # 1e-6 guard dominated Z, collapsing all weights toward zero.
    max_v = torch.max(dist.masked_fill(~valid, float('-inf')), dim=1, keepdim=True)[0]
    # Rows with nothing selected: keep a finite shift (their weights end up 0).
    max_v = max_v.masked_fill(~valid.any(dim=1, keepdim=True), 0.0)
    # Zero the exponent on masked-out slots so exp() cannot overflow there.
    diff = (dist - max_v).masked_fill(~valid, 0.0)
    exp_diff = torch.exp(diff) * mask
    Z = torch.sum(exp_diff, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return exp_diff / Z
def normalize(x, axis=-1):
    """Scale ``x`` to unit L2 length along dimension ``axis``.

    A 1e-12 term is added to the norm so that all-zero slices map to zeros
    instead of NaN.

    Args:
        x: input tensor.
        axis: dimension along which the L2 norm is taken (default: last).

    Returns:
        Tensor with the same shape as ``x``.
    """
    denom = torch.norm(x, p=2, dim=axis, keepdim=True) + 1e-12
    return 1. * x / denom
class TripletLoss_WRT(nn.Module):
    """Weighted Regularized Triplet loss.

    Instead of picking a single hardest positive/negative, every same-label
    pair and every different-label pair of an anchor contributes with a
    softmax weight over its row (larger distances are weighted up for
    positives, smaller distances for negatives). The resulting gap is fed to a
    soft-margin loss.
    """

    def __init__(self):
        super(TripletLoss_WRT, self).__init__()
        self.ranking_loss = nn.SoftMarginLoss()

    def forward(self, inputs, targets, normalize_feature=False):
        """
        Args:
            inputs: feature matrix of shape (N, feat_dim).
            targets: 1-D label tensor of length N.
            normalize_feature: L2-normalize ``inputs`` before measuring distances.

        Returns:
            (loss, correct): soft-margin loss and the number of anchors whose
            weighted negative distance is >= their weighted positive distance.
        """
        feats = normalize(inputs, axis=-1) if normalize_feature else inputs
        pairwise = pdist_torch(feats, feats)
        size = pairwise.size(0)

        # [N, N] indicator matrices for same-label / different-label pairs
        labels_row = targets.expand(size, size)
        same_id = labels_row.eq(labels_row.t()).float()
        diff_id = labels_row.ne(labels_row.t()).float()

        pos_dist = pairwise * same_id
        neg_dist = pairwise * diff_id
        pos_w = softmax_weights(pos_dist, same_id)
        neg_w = softmax_weights(-neg_dist, diff_id)
        hard_pos = (pos_dist * pos_w).sum(dim=1)
        hard_neg = (neg_dist * neg_w).sum(dim=1)

        target_sign = torch.ones_like(hard_pos)
        loss = self.ranking_loss(hard_neg - hard_pos, target_sign)

        # accuracy: anchors whose negatives are at least as far as positives
        n_correct = torch.ge(hard_neg, hard_pos).sum().item()
        return loss, n_correct
class TripletLoss_ADP(nn.Module):
    """Adaptive weighted triplet loss.

    Positives and negatives of each anchor are pooled with row-wise softmax
    weights (as in the weighted regularized triplet).

    Args:
        alpha: scale applied to the distances fed into the weighting softmax.
        gamma: scale applied to the ranking term.
        square: ``0`` uses a soft-margin loss on ``gamma * (neg - pos)``; any
            other value uses the squared gap (clamped at 44) with a sign that
            flips when the positive is farther than the negative.
    """

    def __init__(self, alpha=1, gamma=1, square=0):
        super(TripletLoss_ADP, self).__init__()
        self.ranking_loss = nn.SoftMarginLoss()
        self.alpha = alpha
        self.gamma = gamma
        self.square = square

    def forward(self, inputs, targets, normalize_feature=False):
        """
        Args:
            inputs: feature matrix of shape (N, feat_dim).
            targets: 1-D label tensor of length N.
            normalize_feature: L2-normalize ``inputs`` before measuring distances.

        Returns:
            (loss, correct): the loss tensor and the number of anchors whose
            weighted negative distance is >= their weighted positive distance.
        """
        feats = normalize(inputs, axis=-1) if normalize_feature else inputs
        pairwise = pdist_torch(feats, feats)
        size = pairwise.size(0)

        # [N, N] indicator matrices for same-label / different-label pairs
        labels_row = targets.expand(size, size)
        same_id = labels_row.eq(labels_row.t()).float()
        diff_id = labels_row.ne(labels_row.t()).float()

        pos_dist = pairwise * same_id
        neg_dist = pairwise * diff_id
        pos_w = softmax_weights(pos_dist * self.alpha, same_id)
        neg_w = softmax_weights(-neg_dist * self.alpha, diff_id)
        hard_pos = (pos_dist * pos_w).sum(dim=1)
        hard_neg = (neg_dist * neg_w).sum(dim=1)

        if self.square != 0:
            # Squared gap; sign is -1 where the positive is farther than the
            # negative (a violation) and +1 otherwise.
            gap_sq = ((hard_pos - hard_neg).pow(2) * self.gamma).clamp(max=44)
            violated = (hard_pos > hard_neg).float()
            sign = 1.0 - 2.0 * violated
            loss = self.ranking_loss(gap_sq, sign)
        else:
            sign = torch.ones_like(hard_pos)
            loss = self.ranking_loss(self.gamma * (hard_neg - hard_pos), sign)

        # accuracy: anchors whose negatives are at least as far as positives
        n_correct = torch.ge(hard_neg, hard_pos).sum().item()
        return loss, n_correct
class KLDivLoss(nn.Module):
    """Temperature-scaled KL divergence between two sets of logits.

    ``softmax(label / T)`` (plus 1e-7, detached) is the target distribution and
    ``log_softmax(pred / T)`` the prediction. The summed KL is averaged over the
    batch and multiplied by ``T * T``.

    Args:
        temperature: softening temperature ``T`` (default 3, the previously
            hard-coded value).
    """

    def __init__(self, temperature=3):
        super(KLDivLoss, self).__init__()
        self.temperature = temperature

    def forward(self, pred, label):
        """
        Args:
            pred: 2-D logits (batch_size, num_classes).
            label: logits with the same layout as ``pred``; softmaxed over dim 1
                to build the soft target.

        Returns:
            Scalar loss tensor.
        """
        T = self.temperature
        predict = F.log_softmax(pred / T, dim=1)
        target_data = F.softmax(label / T, dim=1)
        target_data = target_data + 10 ** (-7)  # keeps target.log() finite
        # Detach instead of Variable(.data.cuda()): no gradient through the
        # target, and it stays on the input's device (CPU or any GPU).
        target = target_data.detach()
        loss = T * T * ((target * (target.log() - predict)).sum(1).sum() / target.size()[0])
        return loss
def pdist_torch(emb1, emb2):
    '''
    Compute the Euclidean distance matrix between embeddings1 and embeddings2
    with torch ops (runs on whatever device the inputs live on).

    Args:
        emb1: tensor of shape (m, d).
        emb2: tensor of shape (n, d).

    Returns:
        (m, n) tensor of distances. Squared distances are clamped at 1e-12
        before the sqrt, so the smallest value returned is 1e-6.
    '''
    m, n = emb1.shape[0], emb2.shape[0]
    emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
    emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist_mtx = emb1_pow + emb2_pow
    # ||a||^2 + ||b||^2 - 2 a.b; keyword form replaces the deprecated
    # positional addmm_(beta, alpha, mat1, mat2) overload.
    dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
    dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
    return dist_mtx
def pdist_np(emb1, emb2):
    '''
    Compute the matrix of *squared* Euclidean distances between embeddings1
    and embeddings2 with NumPy (CPU).

    No sqrt and no clipping is applied, so tiny negative values from
    floating-point round-off can appear.

    Args:
        emb1: array of shape (m, d).
        emb2: array of shape (n, d).

    Returns:
        (m, n) array of squared distances.
    '''
    sq_rows = np.square(emb1).sum(axis=1).reshape(-1, 1)
    sq_cols = np.square(emb2).sum(axis=1).reshape(1, -1)
    cross = emb1 @ emb2.T
    return -2 * cross + sq_rows + sq_cols
class MRIC(nn.Module):
    """Identity-level cross-modal contrastive loss.

    Both inputs are L2-normalized. Consecutive rows are then grouped into
    identities of ``num_instances`` samples each, and every group is collapsed
    to one vector. The collapse is a softmax-weighted sum, where each row's
    weight comes from the sum of its dot products with the group. The loss is:

    * the symmetric cross-entropy of the cosine-similarity matrix between the
      two modalities' group vectors (group i of ``x1`` should match group i
      of ``x2``), plus
    * the mean L2 distance between paired group vectors.

    The number of groups is taken from ``x2``; rows beyond that in ``x1`` are
    ignored.

    Args:
        num_instances: samples per identity (default 4, the previously
            hard-coded value).
    """

    def __init__(self, num_instances=4):
        super(MRIC, self).__init__()
        self.num_instances = num_instances
        self.adaptive = nn.Softmax(dim=0)

    def Adaptive_Identity(self, x):
        """Return a (k, 1) column of softmax weights for the k rows of ``x``."""
        sim = torch.sum(torch.mm(x, x.T), dim=1)
        return self.adaptive(sim).view(-1, 1)

    def _aggregate(self, x, b):
        """Collapse each of the first ``b`` row-groups of ``x`` into one row."""
        k = self.num_instances
        pooled = []
        for i in range(b):
            group = x[i * k:(i + 1) * k]
            pooled.append(torch.sum(self.Adaptive_Identity(group) * group, dim=0, keepdim=True))
        return torch.cat(pooled, dim=0)

    def forward(self, x1, x2):
        """
        Returns:
            (loss, x1_groups, x2_groups): scalar loss and the normalized
            (b, feat_dim) group vectors of each modality.
        """
        x1 = F.normalize(x1)
        x2 = F.normalize(x2)
        b = x2.shape[0] // self.num_instances
        x1 = F.normalize(self._aggregate(x1, b))
        x2 = F.normalize(self._aggregate(x2, b))
        center_loss = ((x1 - x2).norm(dim=1, keepdim=True)).mean()
        sim = torch.mm(x1, x2.T)
        # Build labels on the same device as the similarities (was hard-coded .cuda()).
        labels = torch.arange(b, device=sim.device)
        loss_t = F.cross_entropy(sim, labels)
        loss_i = F.cross_entropy(sim.T, labels)
        return (loss_t + loss_i) + center_loss, x1, x2