forked from BestJuly/IIC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLinearAverage.py
58 lines (44 loc) · 1.74 KB
/
LinearAverage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch.autograd import Function
from torch import nn
import math
class LinearAverageOp(Function):
    """Non-parametric instance-discrimination op (NCE-style memory bank).

    forward:  returns temperature-scaled similarities ``x @ memory.T / T``
              of shape (batchSize, N).
    backward: returns the gradient w.r.t. ``x`` and, as a side effect,
              performs a momentum update + L2-renormalization of the
              ``memory`` rows indexed by ``y``.
    """
    @staticmethod
    def forward(ctx, x, y, memory, params):
        """Compute similarities of each sample against every memory slot.

        x:      (batchSize, inputSize) feature batch.
        y:      (batchSize,) indices of each sample's own memory row.
        memory: (N, inputSize) memory bank, assumed L2-normalized rows.
        params: 1-D tensor [temperature T, momentum].
        """
        T = params[0].item()
        # Inner product against the whole bank, scaled by temperature.
        out = torch.mm(x, memory.t())
        out.div_(T)  # batchSize * N
        ctx.save_for_backward(x, memory, y, params)
        return out

    @staticmethod
    def backward(ctx, gradOutput):
        x, memory, y, params = ctx.saved_tensors
        T = params[0].item()
        momentum = params[1].item()
        # Apply the temperature out-of-place: the original in-place
        # gradOutput.div_(T) would mutate the incoming gradient tensor,
        # corrupting it for any other consumer of the same grad.
        grad = gradOutput / T
        # Gradient of the linear similarity w.r.t. x; mm already yields
        # (batchSize, inputSize), so no resize is needed.
        gradInput = torch.mm(grad, memory)
        # Momentum update of the memory rows for this batch.
        # index_select returns a copy, so the in-place ops below do not
        # touch `memory` until index_copy_ writes the result back.
        weight_pos = memory.index_select(0, y.view(-1))
        weight_pos.mul_(momentum)
        weight_pos.add_(x.detach() * (1 - momentum))
        # Re-normalize each updated row to unit L2 norm.
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm)
        memory.index_copy_(0, y, updated_weight)
        # One gradient per forward input; y, memory, params get none.
        return gradInput, None, None, None
class LinearAverage(nn.Module):
    """Module wrapper around LinearAverageOp holding the memory bank.

    Keeps an (outputSize, inputSize) memory buffer and the
    [temperature, momentum] hyper-parameters as buffers so they follow
    the module across devices and into checkpoints.
    """

    def __init__(self, inputSize, outputSize, T=0.07, momentum=0.5):
        super(LinearAverage, self).__init__()
        self.nLem = outputSize  # number of memory entries
        self.register_buffer('params', torch.tensor([T, momentum]))
        # Uniform init in [-stdv, stdv]. NOTE(review): the original first
        # computed 1/sqrt(inputSize) and immediately overwrote it with
        # 1/sqrt(inputSize/3); only the latter was ever used, so the dead
        # assignment is dropped here.
        stdv = 1. / math.sqrt(inputSize / 3)
        self.register_buffer(
            'memory',
            torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))

    def forward(self, x, y):
        """Return temperature-scaled similarities of x against the memory.

        x: (batch, inputSize) features; y: (batch,) memory-row indices.
        The memory bank itself is updated during the backward pass.
        """
        out = LinearAverageOp.apply(x, y, self.memory, self.params)
        return out