forked from Syliz517/CLIP-ReID
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_loss.py
89 lines (71 loc) · 3.57 KB
/
make_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# encoding: utf-8
"""
@author: liaoxingyu
@contact: [email protected]
"""
import torch.nn.functional as F
from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy
from .triplet_loss import TripletLoss
from .center_loss import CenterLoss
def _sum_losses(criterion, outputs, target):
    """Apply ``criterion`` to ``outputs`` (or to each element of a list of outputs) and sum.

    Args:
        criterion: Callable ``criterion(output, target) -> loss``.
        outputs: A single model output, or a ``list`` of outputs (one per head).
        target: Labels passed unchanged to every ``criterion`` call.

    Returns:
        ``criterion(outputs, target)`` for a single output, otherwise the sum of
        ``criterion(o, target)`` over the list elements.
    """
    if isinstance(outputs, list):
        return sum(criterion(out, target) for out in outputs)
    return criterion(outputs, target)


def make_loss(cfg, num_classes):  # modified by gu
    """Build the training loss function and the center-loss criterion from ``cfg``.

    Args:
        cfg: Config node. Reads ``DATALOADER.SAMPLER``, ``MODEL.METRIC_LOSS_TYPE``,
            ``MODEL.NO_MARGIN``, ``MODEL.IF_LABELSMOOTH``, ``SOLVER.MARGIN`` and the
            ``MODEL.ID_LOSS_WEIGHT`` / ``TRIPLET_LOSS_WEIGHT`` / ``I2T_LOSS_WEIGHT``
            factors.
        num_classes: Number of identity classes, forwarded to ``CenterLoss`` and,
            when label smoothing is on, to ``CrossEntropyLabelSmooth``.

    Returns:
        ``(loss_func, center_criterion)``:

        * for sampler ``'softmax'``, ``loss_func(score, feat, target)`` returns the
          plain cross-entropy of ``score``;
        * for sampler ``'softmax_triplet'``,
          ``loss_func(score, feat, target, target_cam, i2tscore=None)`` returns
          ``ID_LOSS_WEIGHT * id_loss + TRIPLET_LOSS_WEIGHT * triplet_loss`` plus,
          when ``i2tscore`` is given, ``I2T_LOSS_WEIGHT * i2t_loss``. ``score`` and
          ``feat`` may each be a single tensor or a list whose per-element losses
          are summed. ``target_cam`` is accepted but not used.

    Raises:
        ValueError: if ``DATALOADER.SAMPLER`` is neither ``'softmax'`` nor
            ``'softmax_triplet'``; and, at call time of the ``'softmax_triplet'``
            loss, if ``MODEL.METRIC_LOSS_TYPE`` is not exactly ``'triplet'``.
    """
    sampler = cfg.DATALOADER.SAMPLER
    # NOTE(review): hard-coded CenterLoss feature width; presumably matches the
    # backbone's embedding size — confirm against the model config.
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss

    triplet = None
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        print('expected METRIC_LOSS_TYPE should be triplet '
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    # One classification criterion shared by the ID loss and the image-to-text loss.
    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        id_criterion = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)
    else:
        id_criterion = F.cross_entropy

    def _triplet_term(feats, target):
        # TripletLoss returns a tuple; only its first element is the loss value.
        return triplet(feats, target)[0]

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)
    elif sampler == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam, i2tscore=None):
            if cfg.MODEL.METRIC_LOSS_TYPE != 'triplet':
                # Previously this printed and returned None, which only failed
                # later (e.g. on ``loss.backward()``); fail loudly instead.
                raise ValueError('expected METRIC_LOSS_TYPE should be triplet '
                                 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
            id_loss = _sum_losses(id_criterion, score, target)
            tri_loss = _sum_losses(_triplet_term, feat, target)
            loss = cfg.MODEL.ID_LOSS_WEIGHT * id_loss + cfg.MODEL.TRIPLET_LOSS_WEIGHT * tri_loss
            if i2tscore is not None:
                i2t_loss = id_criterion(i2tscore, target)
                loss = cfg.MODEL.I2T_LOSS_WEIGHT * i2t_loss + loss
            return loss
    else:
        # Previously this printed and then crashed with UnboundLocalError at the
        # return statement because ``loss_func`` was never defined.
        raise ValueError('expected sampler should be softmax or softmax_triplet '
                         'but got {}'.format(sampler))
    return loss_func, center_criterion