forked from CoinCheung/pytorch-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfo_nce_dist.py
127 lines (99 loc) · 4.86 KB
/
info_nce_dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
'''
This can be used to train self-supervised learning models such as simclr. The merit of this implementation is that it makes full use of negative samples even in distributed training mode. For example, if you use 8 gpus and each gpu has a batch size of 32 x 2 (two views of each image), every sample is trained against 32 x 2 x 8 - 2 = 510 negative samples.
This implementation uses a "model distributed" method rather than a "data distributed" method, so you should use it in distributed training mode, but should not wrap it in pytorch's nn.parallel.DistributedDataParallel module.
An example of usage is like this:
```python
# init distributed mode
dist.init_process_group(backend='nccl')
model = define_model()
model = nn.parallel.DistributedDataParallel(model) # model use distributed mode
crit = InfoNceDist(temper=0.1, margin=0.) # crit not use distributed mode
model.cuda()
crit.cuda()
params = list(model.parameters()) + list(crit.parameters())
optim = SGD(params, lr=1e-3)
for (ims_view1, ims_view2), ids in dataloader:
    ims_view1, ims_view2 = ims_view1.cuda(), ims_view2.cuda()
    ids = ids.cuda()
    embs1 = model(ims_view1)
    embs2 = model(ims_view2)
    loss = crit(embs1, embs2)
    optim.zero_grad()
    loss.backward()
    optim.step()
```
For details of simclr, please refer to their paper: https://arxiv.org/pdf/2002.05709.pdf
Please note that this is different from the info-nce used in moco series, where a negative queue is given. This works in simclr's way.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.cuda.amp as amp
class InfoNceDist(nn.Module):
    '''
    InfoNCE criterion for paired embeddings (simclr style).

    Both inputs are L2-normalized along dim 1, turned into logits and labels
    by InfoNceFunction, and scored with cross entropy.

    Args:
        temper: temperature; logits are multiplied by 1 / temper.
        margin: value forwarded to InfoNceFunction as the margin.
    '''

    def __init__(self, temper=0.1, margin=0.):
        super().__init__()
        # only a margin is kept and no separate scale "s": the temperature
        # already plays the role of a scale factor
        self.temp_factor = 1. / temper
        self.margin = margin
        self.crit = nn.CrossEntropyLoss()

    def forward(self, embs1, embs2):
        '''
        embs1, embs2: n x c, row i of embs1 is paired with row i of embs2
        returns the cross entropy loss over the logits from InfoNceFunction
        meant for distributed mode, do not wrap this module in a
        distributed-parallel wrapper
        '''
        normed1, normed2 = (F.normalize(e, dim=1) for e in (embs1, embs2))
        logits, labels = InfoNceFunction.apply(
            normed1, normed2, self.temp_factor, self.margin)
        # labels are plain indices, keep them out of the autograd graph
        return self.crit(logits, labels.detach())
class InfoNceFunction(torch.autograd.Function):
    '''
    Builds simclr-style InfoNCE logits and labels with a hand-written backward.

    forward gathers the embeddings of every rank, computes the similarity of
    the local 2N embeddings against all 2N * world_size gathered embeddings,
    masks each sample against itself, subtracts the margin from the positive
    logit and applies the temperature factor.

    When torch.distributed is unavailable or no process group is initialized,
    the function behaves as a single process (world_size = 1, rank = 0), and
    all helper tensors are created on the device of embs1.
    '''

    @staticmethod
    def _gather(embs, world_size, distributed):
        '''Return the embeddings of all ranks stacked along dim 0, in rank order.'''
        if not distributed:
            return embs
        n, c = embs.size()
        # allocate on the input's device directly instead of CPU + copy
        gathered = torch.zeros(
            size=[n * world_size, c], dtype=embs.dtype, device=embs.device)
        # the chunks are views, so all_gather fills `gathered` in place
        dist.all_gather(list(gathered.chunk(world_size, dim=0)), embs)
        return gathered

    @staticmethod
    @amp.custom_fwd
    def forward(ctx, embs1, embs2, temper_factor, margin):
        '''
        embs1, embs2: N x C, row i of embs1 is paired with row i of embs2
        temper_factor: multiplier applied to the logits
        margin: subtracted from the positive logit before the multiplier
        returns (logits, labels): logits is 2N x (2N * world_size), labels
        holds the column index of the positive for each of the 2N rows
        '''
        assert embs1.size() == embs2.size(), \
            'embs1 and embs2 must have the same shape'
        N, C = embs1.size()
        device = embs1.device

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        # gather for negative; columns are [all ranks' embs1 | all ranks' embs2]
        all_embs = torch.cat([
            InfoNceFunction._gather(embs1, world_size, distributed),
            InfoNceFunction._gather(embs2, world_size, distributed),
        ], dim=0)
        embs12 = torch.cat([embs1, embs2], dim=0)
        logits = torch.einsum('ac,bc->ab', embs12, all_embs)

        # inds2[i] is the column holding local row i itself
        inds1 = torch.arange(N * 2, device=device)
        offsets = torch.arange(N, device=device)
        inds2 = torch.cat([
            offsets + rank * N,
            offsets + (rank + world_size) * N,
        ], dim=0)
        # mask off one sample to itself
        logits[inds1, inds2] = -10000.  # such that exp should be 0

        # the positive of a row is the other view of the same sample, so the
        # labels are the self columns with the two halves swapped:
        # rows 0~N-1 -> (rank + world_size) * N + i, rows N~2N-1 -> rank * N + i
        labels = inds2.view(2, N).flip(dims=(0,)).reshape(-1)

        # subtract margin, apply temperature
        logits[inds1, labels] -= margin
        logits *= temper_factor

        ctx.vars = inds1, inds2, embs12, all_embs, temper_factor
        # remembered so that backward does not query the process group again
        ctx.dist_info = rank, world_size, N
        return logits, labels

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_logits, grad_label):
        '''
        grad_logits: gradient w.r.t. the logits returned by forward
        grad_label: gradient slot of the labels output, not used
        returns gradients for (embs1, embs2, temper_factor, margin); the last
        two are None
        '''
        inds1, inds2, embs12, all_embs, temper_factor = ctx.vars
        rank, world_size, N = ctx.dist_info

        # undo the temperature scaling; this also creates a fresh tensor, so
        # the in-place masking below does not touch the incoming gradient
        grad_logits = grad_logits * temper_factor
        # self-similarity entries were overwritten by a constant in forward
        grad_logits[inds1, inds2] = 0

        grad_embs12 = torch.einsum('ab,bc->ac', grad_logits, all_embs)
        grad_all_embs = torch.einsum('ab,ac->bc', grad_logits, embs12)

        # NOTE(review): no collective is called here, so only the part of
        # grad_all_embs computed on this rank reaches the local embeddings;
        # confirm that dropping the other ranks' contributions is intended.
        start1 = rank * N
        start2 = (rank + world_size) * N
        grad_embs1 = grad_embs12[:N] + grad_all_embs[start1:start1 + N]
        grad_embs2 = grad_embs12[N:] + grad_all_embs[start2:start2 + N]
        return grad_embs1, grad_embs2, None, None