-
Notifications
You must be signed in to change notification settings - Fork 0
/
LookupTable.lua
156 lines (135 loc) · 4.45 KB
/
LookupTable.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
-- nn.LookupTable: a module whose forward pass gathers rows of a learnable
-- weight tensor selected by integer indices (see updateOutput).
-- `parent` is the nn.Module class table, used below for super-calls.
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')
-- Class version tag. NOTE(review): presumably consulted when deserializing
-- modules saved by older versions of this class — confirm.
LookupTable.__version = 3
--- Construct a lookup table whose weight has `nIndex` rows.
-- The remaining arguments give the shape of each row. They can be passed
-- either as separate numbers (`nn.LookupTable(n, d1, d2, ...)`) or as one
-- non-number size container (`nn.LookupTable(n, {d1, d2})`). The container
-- must support `#` and integer indexing; a table or torch.LongStorage
-- presumably qualifies.
-- Sets:
--   self.size       LongStorage {nIndex, d1, d2, ...}
--   self.batchSize  LongStorage {1, nIndex, d1, d2, ...}
--   self.weight     tensor of shape self.size, randomly initialised by reset()
--   self.gradWeight zero tensor of shape self.size
-- @param nIndex number of rows in the weight table
-- @param ... row shape, as numbers or a single size container
function LookupTable:__init(nIndex, ...)
   parent.__init(self)

   -- Work out where the row dimensions come from and how many there are.
   local extra = {...}
   local nExtra = select('#', ...)
   local dimSource, dimCount
   if nExtra == 1 and type(extra[1]) ~= "number" then
      -- a single non-number argument is a size container
      dimSource = extra[1]
      dimCount = #dimSource
   else
      dimSource = extra
      dimCount = nExtra
   end

   -- Full weight shape: slot 1 is the row count, slots 2.. are the row dims.
   self.size = torch.LongStorage(dimCount + 1)
   for d = 1, dimCount do
      self.size[d + 1] = dimSource[d]
   end
   self.size[1] = nIndex

   -- Batched shape: a leading 1 followed by a copy of self.size.
   local nDims = #self.size
   local batchShape = torch.LongTensor(nDims + 1)
   batchShape[1] = 1
   batchShape:narrow(1, 2, nDims):copy(torch.LongTensor(self.size))
   self.batchSize = batchShape:storage()

   -- Parameters and bookkeeping for sparse gradient accumulation.
   self.weight = torch.Tensor(self.size)
   self.gradWeight = torch.Tensor(self.size):zero()
   self.inputs = {}        -- row index -> number of times it was touched
   self.accUpdate = false
   self.nBackward = 0

   self:reset()
end
--- Switch the module to direct-update mode.
-- Releases `gradWeight` and sets `accUpdate`. After this call, weights must
-- be changed with accUpdateGradParameters; updateParameters asserts against
-- this mode.
function LookupTable:accUpdateOnly()
   self.gradWeight = nil   -- no gradient buffer is kept in this mode
   self.accUpdate = true
end
--- Re-initialise every weight entry from a normal distribution N(0, stdv).
-- When the global `nn.oldSeed` flag is set, each element is drawn with a
-- separate torch.normal call rather than one in-place :normal() fill.
-- NOTE(review): presumably this reproduces an older random-number sequence
-- — confirm.
-- @param stdv standard deviation (defaults to 1)
function LookupTable:reset(stdv)
   local sigma = stdv or 1
   if not nn.oldSeed then
      self.weight:normal(0, sigma)
      return
   end
   self.weight:apply(function()
      return torch.normal(0, sigma)
   end)
end
--- Forward pass: gather the rows of `self.weight` selected by `input`.
-- Accepted inputs:
--   * 1D tensor of n indices: the output holds the n selected weight rows
--     (n x <row dims>).
--   * 2D tensor of b x n indices: the output is reshaped to
--     b x n x <row dims>.
-- Any other rank raises an error instead of silently returning the stale
-- output of a previous call.
-- The row dims are read from the current `self.weight`, not from the
-- constructor-time `self.size`. This keeps the reshape correct when rows
-- have more than one dimension, when no row dims were given, or when
-- `self.weight` has been replaced.
-- @param input tensor of row indices (converted to a contiguous LongTensor
--   if needed)
-- @return self.output
function LookupTable:updateOutput(input)
   -- index() needs a contiguous torch.LongTensor of indices. Copy anything
   -- else into a reusable buffer.
   if (not input:isContiguous()) or torch.type(input) ~= 'torch.LongTensor' then
      self._indices = self._indices or torch.LongTensor()
      self._indices:resize(input:size()):copy(input)
      input = self._indices
   end

   if input:dim() == 1 then
      local nIndex = input:size(1)
      -- NOTE(review): this overwrites self.size[1], which __init set to the
      -- weight's row count. It is kept so existing readers of self.size see
      -- unchanged state — confirm whether anything relies on it.
      self.size[1] = nIndex
      self.output:index(self.weight, 1, input)
   elseif input:dim() == 2 then
      local nExample = input:size(1)
      local nIndex = input:size(2)
      -- NOTE(review): batchSize[2] initially holds the weight row count. It
      -- is kept as-is for any external readers of self.batchSize.
      self.batchSize[1] = nExample
      self.batchSize[2] = nIndex
      -- Flatten the index matrix, gather the rows, then restore the batch
      -- layout.
      self._inputView = self._inputView or torch.LongTensor()
      self._inputView:view(input, -1)
      self.output:index(self.weight, 1, self._inputView)
      local weightDims = self.weight:dim()
      local outShape = torch.LongStorage(weightDims + 1)
      outShape[1] = nExample
      outShape[2] = nIndex
      for d = 2, weightDims do
         outShape[d + 1] = self.weight:size(d)
      end
      self.output = self.output:view(outShape)
   else
      error("LookupTable: input must be a vector or matrix of indices, got a "
         .. input:dim() .. "D tensor")
   end
   return self.output
end
--- Clear accumulated gradients.
-- Only the gradWeight rows recorded in `self.inputs` (the rows touched
-- since the last reset) are zeroed, and only when not in accUpdateOnly
-- mode. The touched-row record and the backward counter are then reset.
function LookupTable:zeroGradParameters()
   local touched = self.inputs
   if not self.accUpdate then
      local gradWeight = self.gradWeight
      for rowIndex in pairs(touched) do
         gradWeight:select(1, rowIndex):zero()
      end
   end
   self.nBackward = 0
   self.inputs = {}
end
--- Accumulate gradients into the rows of `self.gradWeight` selected by `input`.
-- For each index k at position p, `scale * gradOutput[p]` is added to
-- gradWeight row k. `self.inputs[k]` counts how often row k was touched;
-- zeroGradParameters and updateParameters use this record.
-- `self.nBackward` grows by 1 for a 1D input and by the batch size for a
-- 2D input.
-- Errors:
--   * accUpdateOnly mode (gradWeight released): raises a clear assertion
--     instead of a nil-index error.
--   * input rank other than 1 or 2: raises an error instead of silently
--     accumulating nothing.
-- @param input 1D or 2D tensor of row indices
-- @param gradOutput gradient w.r.t. the output of updateOutput
-- @param scale multiplier for the gradient (defaults to 1)
function LookupTable:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   assert(self.gradWeight ~= nil,
      "LookupTable: gradWeight is nil (accUpdateOnly mode); use accUpdateGradParameters instead")

   local gradWeight = self.gradWeight
   local touched = self.inputs
   -- record a visit to row k and add its scaled gradient row
   local function accumulate(k, gradRow)
      touched[k] = (touched[k] or 0) + 1
      gradWeight:select(1, k):add(scale, gradRow)
   end

   if input:dim() == 1 then
      self.nBackward = self.nBackward + 1
      for i = 1, input:size(1) do
         accumulate(input[i], gradOutput:select(1, i))
      end
   elseif input:dim() == 2 then
      self.nBackward = self.nBackward + input:size(1)
      for i = 1, input:size(1) do
         local indexRow = input:select(1, i)
         local gradRows = gradOutput:select(1, i)
         for j = 1, indexRow:size(1) do
            accumulate(indexRow[j], gradRows:select(1, j))
         end
      end
   else
      error("LookupTable: input must be a vector or matrix of indices, got a "
         .. input:dim() .. "D tensor")
   end
end
--- Apply gradients directly to `self.weight` without using gradWeight.
-- For each index k at position p, weight row k is updated in place with
-- `-lr * self:scaleUpdateByKey(k) * gradOutput[p]`.
-- An input rank other than 1 or 2 raises an error instead of silently
-- skipping the update.
-- @param input 1D or 2D tensor of row indices
-- @param gradOutput gradient w.r.t. the output of updateOutput
-- @param lr learning rate
function LookupTable:accUpdateGradParameters(input, gradOutput, lr)
   -- apply one scaled gradient row directly to weight row k
   local function update(k, gradRow)
      local kscale = self:scaleUpdateByKey(k)
      self.weight:select(1, k):add(-lr * kscale, gradRow)
   end

   if input:dim() == 1 then
      for i = 1, input:size(1) do
         update(input[i], gradOutput:select(1, i))
      end
   elseif input:dim() == 2 then
      for i = 1, input:size(1) do
         local indexRow = input:select(1, i)
         local gradRows = gradOutput:select(1, i)
         for j = 1, indexRow:size(1) do
            update(indexRow[j], gradRows:select(1, j))
         end
      end
   else
      error("LookupTable: input must be a vector or matrix of indices, got a "
         .. input:dim() .. "D tensor")
   end
end
--- Gradient-descent step on the rows touched since the last zeroGradParameters.
-- Each touched row k moves by
-- `-learningRate * self:scaleUpdateByKey(k) * gradWeight[k]`.
-- This is not allowed in accUpdateOnly mode.
-- @param learningRate step size
function LookupTable:updateParameters(learningRate)
   assert(not self.accUpdate, "use accUpdateGradParameters instead")
   for rowIndex in pairs(self.inputs) do
      local step = -learningRate * self:scaleUpdateByKey(rowIndex)
      self.weight:select(1, rowIndex):add(step, self.gradWeight:select(1, rowIndex))
   end
end
--- Convert the module's tensors to another tensor type.
-- The cached index buffers `_indices` and `_inputView` are dropped first.
-- updateOutput recreates them on demand as torch.LongTensor.
-- NOTE(review): presumably they are dropped so parent.type does not convert
-- them away from LongTensor — confirm.
-- The parent's return value is now passed through. nn.Module:type
-- conventionally returns self, so chained calls such as
-- `nn.LookupTable(...):float()` no longer yield nil.
-- @param dstType target tensor type name, e.g. 'torch.FloatTensor'
-- @param ... any further arguments, forwarded unchanged to nn.Module.type
-- @return the result of parent.type (normally self)
function LookupTable:type(dstType, ...)
   self._indices = nil
   self._inputView = nil
   return parent.type(self, dstType, ...)
end
--- Per-row multiplier applied to the learning rate when row `inputKey` is
-- updated.
-- updateParameters and accUpdateGradParameters call this through `self`, so
-- a subclass can override it to scale updates per index. The base class
-- applies no key-based scaling.
-- @param inputKey index of the weight row being updated (unused here)
-- @return always 1
function LookupTable:scaleUpdateByKey(inputKey)
   local noScaling = 1
   return noScaling
end
-- When sharing, no separate gradient accumulation is needed, so the shared
-- variant reuses the direct-update implementation.
-- Note: this copies the function value at load time. A subclass that
-- overrides accUpdateGradParameters does not change the inherited
-- sharedAccUpdateGradParameters.
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters