@@ -9,7 +9,8 @@ function Euclidean:__init(inputSize,outputSize)
9
9
-- state
10
10
self .gradInput :resize (inputSize )
11
11
self .output :resize (outputSize )
12
- self .temp = torch .Tensor (inputSize )
12
+
13
+ self .fastBackward = true
13
14
14
15
self :reset ()
15
16
end
@@ -31,37 +32,159 @@ function Euclidean:reset(stdv)
31
32
end
32
33
end
33
34
35
+ local function view (res , src , ...)
36
+ local args = {... }
37
+ if src :isContiguous () then
38
+ res :view (src , unpack (args ))
39
+ else
40
+ res :reshape (src , unpack (args ))
41
+ end
42
+ end
43
+
34
44
function Euclidean :updateOutput (input )
35
- self .output :zero ()
36
- for o = 1 ,self .weight :size (2 ) do
37
- self .output [o ] = input :dist (self .weight :select (2 ,o ))
45
+ -- lazy initialize buffers
46
+ self ._input = self ._input or input .new ()
47
+ self ._weight = self ._weight or self .weight .new ()
48
+ self ._expand = self ._expand or self .output .new ()
49
+ self ._expand2 = self ._expand2 or self .output .new ()
50
+ self ._repeat = self ._repeat or self .output .new ()
51
+ self ._repeat2 = self ._repeat2 or self .output .new ()
52
+
53
+ local inputSize , outputSize = self .weight :size (1 ), self .weight :size (2 )
54
+
55
+ -- y_j = || w_j - x || = || x - w_j ||
56
+ if input :dim () == 1 then
57
+ view (self ._input , input , inputSize , 1 )
58
+ self ._expand :expandAs (self ._input , self .weight )
59
+ self ._repeat :resizeAs (self ._expand ):copy (self ._expand )
60
+ self ._repeat :add (- 1 , self .weight )
61
+ self .output :norm (self ._repeat , 2 , 1 )
62
+ self .output :resize (outputSize )
63
+ elseif input :dim () == 2 then
64
+ local batchSize = input :size (1 )
65
+
66
+ view (self ._input , input , batchSize , inputSize , 1 )
67
+ self ._expand :expand (self ._input , batchSize , inputSize , outputSize )
68
+ -- make the expanded tensor contiguous (requires lots of memory)
69
+ self ._repeat :resizeAs (self ._expand ):copy (self ._expand )
70
+
71
+ self ._weight :view (self .weight , 1 , inputSize , outputSize )
72
+ self ._expand2 :expandAs (self ._weight , self ._repeat )
73
+
74
+ if torch .type (input ) == ' torch.CudaTensor' then
75
+ -- requires lots of memory, but minimizes cudaMallocs and loops
76
+ self ._repeat2 :resizeAs (self ._expand2 ):copy (self ._expand2 )
77
+ self ._repeat :add (- 1 , self ._repeat2 )
78
+ else
79
+ self ._repeat :add (- 1 , self ._expand2 )
80
+ end
81
+
82
+ self .output :norm (self ._repeat , 2 , 2 )
83
+ self .output :resize (batchSize , outputSize )
84
+ else
85
+ error " 1D or 2D input expected"
38
86
end
87
+
39
88
return self .output
40
89
end
41
90
42
91
function Euclidean :updateGradInput (input , gradOutput )
43
- self :updateOutput (input )
44
- if self .gradInput then
45
- self .gradInput :zero ()
46
- for o = 1 ,self .weight :size (2 ) do
47
- if self .output [o ] ~= 0 then
48
- self .temp :copy (input ):add (- 1 ,self .weight :select (2 ,o ))
49
- self .temp :mul (gradOutput [o ]/ self .output [o ])
50
- self .gradInput :add (self .temp )
51
- end
92
+ if not self .gradInput then
93
+ return
94
+ end
95
+
96
+ self ._div = self ._div or input .new ()
97
+ self ._output = self ._output or self .output .new ()
98
+ self ._gradOutput = self ._gradOutput or input .new ()
99
+ self ._expand3 = self ._expand3 or input .new ()
100
+
101
+ if not self .fastBackward then
102
+ self :updateOutput (input )
103
+ end
104
+
105
+ local inputSize , outputSize = self .weight :size (1 ), self .weight :size (2 )
106
+
107
+ --[[
108
+ dy_j -2 * (w_j - x) x - w_j
109
+ ---- = --------------- = -------
110
+ dx 2 || w_j - x || y_j
111
+ --]]
112
+
113
+ -- to prevent div by zero (NaN) bugs
114
+ self ._output :resizeAs (self .output ):copy (self .output ):add (0.0000001 )
115
+ view (self ._gradOutput , gradOutput , gradOutput :size ())
116
+ self ._div :cdiv (gradOutput , self ._output )
117
+ if input :dim () == 1 then
118
+ self ._div :resize (1 , outputSize )
119
+ self ._expand3 :expandAs (self ._div , self .weight )
120
+
121
+ if torch .type (input ) == ' torch.CudaTensor' then
122
+ self ._repeat2 :resizeAs (self ._expand3 ):copy (self ._expand3 )
123
+ self ._repeat2 :cmul (self ._repeat )
124
+ else
125
+ self ._repeat2 :cmul (self ._repeat , self ._expand3 )
126
+ end
127
+
128
+ self .gradInput :sum (self ._repeat2 , 2 )
129
+ self .gradInput :resizeAs (input )
130
+ elseif input :dim () == 2 then
131
+ local batchSize = input :size (1 )
132
+
133
+ self ._div :resize (batchSize , 1 , outputSize )
134
+ self ._expand3 :expand (self ._div , batchSize , inputSize , outputSize )
135
+
136
+ if torch .type (input ) == ' torch.CudaTensor' then
137
+ self ._repeat2 :resizeAs (self ._expand3 ):copy (self ._expand3 )
138
+ self ._repeat2 :cmul (self ._repeat )
139
+ else
140
+ self ._repeat2 :cmul (self ._repeat , self ._expand3 )
52
141
end
53
- return self .gradInput
142
+
143
+ self .gradInput :sum (self ._repeat2 , 3 )
144
+ self .gradInput :resizeAs (input )
145
+ else
146
+ error " 1D or 2D input expected"
54
147
end
148
+
149
+ return self .gradInput
55
150
end
56
151
57
152
function Euclidean :accGradParameters (input , gradOutput , scale )
58
- self : updateOutput ( input )
153
+ local inputSize , outputSize = self . weight : size ( 1 ), self . weight : size ( 2 )
59
154
scale = scale or 1
60
- for o = 1 ,self .weight :size (2 ) do
61
- if self .output [o ] ~= 0 then
62
- self .temp :copy (self .weight :select (2 ,o )):add (- 1 ,input )
63
- self .temp :mul (gradOutput [o ]/ self .output [o ])
64
- self .gradWeight :select (2 ,o ):add (scale , self .temp )
65
- end
155
+
156
+ --[[
157
+ dy_j 2 * (w_j - x) w_j - x
158
+ ---- = --------------- = -------
159
+ dw_j 2 || w_j - x || y_j
160
+ --]]
161
+ -- assumes a preceding call to updateGradInput
162
+ if input :dim () == 1 then
163
+ self .gradWeight :add (- scale , self ._repeat2 )
164
+ elseif input :dim () == 2 then
165
+ self ._sum = self ._sum or input .new ()
166
+ self ._sum :sum (self ._repeat2 , 1 )
167
+ self ._sum :resize (inputSize , outputSize )
168
+ self .gradWeight :add (- scale , self ._sum )
169
+ else
170
+ error " 1D or 2D input expected"
171
+ end
172
+ end
173
+
174
+ function Euclidean :type (type )
175
+ if type then
176
+ -- prevent premature memory allocations
177
+ self ._input = nil
178
+ self ._output = nil
179
+ self ._gradOutput = nil
180
+ self ._weight = nil
181
+ self ._div = nil
182
+ self ._sum = nil
183
+ self ._expand = nil
184
+ self ._expand2 = nil
185
+ self ._expand3 = nil
186
+ self ._repeat = nil
187
+ self ._repeat2 = nil
66
188
end
189
+ return parent .type (self , type )
67
190
end
0 commit comments