
Commit dace3a2 (1 parent: 27acf63)

Euclidean

2 files changed: +191 −23 lines

Euclidean.lua (+144 −21)
@@ -9,7 +9,8 @@ function Euclidean:__init(inputSize,outputSize)
    -- state
    self.gradInput:resize(inputSize)
    self.output:resize(outputSize)
-   self.temp = torch.Tensor(inputSize)
+
+   self.fastBackward = true
 
    self:reset()
 end
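Note: the fastBackward flag added above (default true) lets updateGradInput reuse the buffers left behind by the immediately preceding updateOutput call instead of recomputing the forward pass. A minimal sketch of how a caller might use it, inferred from this diff rather than from documented API:

   local module = nn.Euclidean(5, 7)
   -- default: backward assumes forward was just called on the same input,
   -- so the distance buffers (self._repeat, self.output) are still valid
   module.fastBackward = true
   -- set to false when backward may run without a matching forward,
   -- e.g. under the Jacobian tester, as test.lua does below
   module.fastBackward = false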
@@ -31,37 +32,159 @@ function Euclidean:reset(stdv)
    end
 end
 
+local function view(res, src, ...)
+   local args = {...}
+   if src:isContiguous() then
+      res:view(src, unpack(args))
+   else
+      res:reshape(src, unpack(args))
+   end
+end
+
 function Euclidean:updateOutput(input)
-   self.output:zero()
-   for o = 1,self.weight:size(2) do
-      self.output[o] = input:dist(self.weight:select(2,o))
+   -- lazy initialize buffers
+   self._input = self._input or input.new()
+   self._weight = self._weight or self.weight.new()
+   self._expand = self._expand or self.output.new()
+   self._expand2 = self._expand2 or self.output.new()
+   self._repeat = self._repeat or self.output.new()
+   self._repeat2 = self._repeat2 or self.output.new()
+
+   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
+
+   -- y_j = || w_j - x || = || x - w_j ||
+   if input:dim() == 1 then
+      view(self._input, input, inputSize, 1)
+      self._expand:expandAs(self._input, self.weight)
+      self._repeat:resizeAs(self._expand):copy(self._expand)
+      self._repeat:add(-1, self.weight)
+      self.output:norm(self._repeat, 2, 1)
+      self.output:resize(outputSize)
+   elseif input:dim() == 2 then
+      local batchSize = input:size(1)
+
+      view(self._input, input, batchSize, inputSize, 1)
+      self._expand:expand(self._input, batchSize, inputSize, outputSize)
+      -- make the expanded tensor contiguous (requires lots of memory)
+      self._repeat:resizeAs(self._expand):copy(self._expand)
+
+      self._weight:view(self.weight, 1, inputSize, outputSize)
+      self._expand2:expandAs(self._weight, self._repeat)
+
+      if torch.type(input) == 'torch.CudaTensor' then
+         -- requires lots of memory, but minimizes cudaMallocs and loops
+         self._repeat2:resizeAs(self._expand2):copy(self._expand2)
+         self._repeat:add(-1, self._repeat2)
+      else
+         self._repeat:add(-1, self._expand2)
+      end
+
+      self.output:norm(self._repeat, 2, 2)
+      self.output:resize(batchSize, outputSize)
+   else
+      error"1D or 2D input expected"
    end
+
    return self.output
 end
 
 function Euclidean:updateGradInput(input, gradOutput)
-   self:updateOutput(input)
-   if self.gradInput then
-      self.gradInput:zero()
-      for o = 1,self.weight:size(2) do
-         if self.output[o] ~= 0 then
-            self.temp:copy(input):add(-1,self.weight:select(2,o))
-            self.temp:mul(gradOutput[o]/self.output[o])
-            self.gradInput:add(self.temp)
-         end
+   if not self.gradInput then
+      return
+   end
+
+   self._div = self._div or input.new()
+   self._output = self._output or self.output.new()
+   self._gradOutput = self._gradOutput or input.new()
+   self._expand3 = self._expand3 or input.new()
+
+   if not self.fastBackward then
+      self:updateOutput(input)
+   end
+
+   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
+
+   --[[
+   dy_j   -2 * (w_j - x)    x - w_j
+   ---- = --------------- = -------
+    dx    2 || w_j - x ||     y_j
+   --]]
+
+   -- to prevent div by zero (NaN) bugs
+   self._output:resizeAs(self.output):copy(self.output):add(0.0000001)
+   view(self._gradOutput, gradOutput, gradOutput:size())
+   self._div:cdiv(gradOutput, self._output)
+   if input:dim() == 1 then
+      self._div:resize(1, outputSize)
+      self._expand3:expandAs(self._div, self.weight)
+
+      if torch.type(input) == 'torch.CudaTensor' then
+         self._repeat2:resizeAs(self._expand3):copy(self._expand3)
+         self._repeat2:cmul(self._repeat)
+      else
+         self._repeat2:cmul(self._repeat, self._expand3)
+      end
+
+      self.gradInput:sum(self._repeat2, 2)
+      self.gradInput:resizeAs(input)
+   elseif input:dim() == 2 then
+      local batchSize = input:size(1)
+
+      self._div:resize(batchSize, 1, outputSize)
+      self._expand3:expand(self._div, batchSize, inputSize, outputSize)
+
+      if torch.type(input) == 'torch.CudaTensor' then
+         self._repeat2:resizeAs(self._expand3):copy(self._expand3)
+         self._repeat2:cmul(self._repeat)
+      else
+         self._repeat2:cmul(self._repeat, self._expand3)
       end
-      return self.gradInput
+
+      self.gradInput:sum(self._repeat2, 3)
+      self.gradInput:resizeAs(input)
+   else
+      error"1D or 2D input expected"
    end
+
+   return self.gradInput
 end
 
 function Euclidean:accGradParameters(input, gradOutput, scale)
-   self:updateOutput(input)
+   local inputSize, outputSize = self.weight:size(1), self.weight:size(2)
    scale = scale or 1
-   for o = 1,self.weight:size(2) do
-      if self.output[o] ~= 0 then
-         self.temp:copy(self.weight:select(2,o)):add(-1,input)
-         self.temp:mul(gradOutput[o]/self.output[o])
-         self.gradWeight:select(2,o):add(scale, self.temp)
-      end
+
+   --[[
+   dy_j    2 * (w_j - x)    w_j - x
+   ---- = --------------- = -------
+   dw_j   2 || w_j - x ||     y_j
+   --]]
+   -- assumes a preceding call to updateGradInput
+   if input:dim() == 1 then
+      self.gradWeight:add(-scale, self._repeat2)
+   elseif input:dim() == 2 then
+      self._sum = self._sum or input.new()
+      self._sum:sum(self._repeat2, 1)
+      self._sum:resize(inputSize, outputSize)
+      self.gradWeight:add(-scale, self._sum)
+   else
+      error"1D or 2D input expected"
    end
+end
+
+function Euclidean:type(type)
+   if type then
+      -- prevent premature memory allocations
+      self._input = nil
+      self._output = nil
+      self._gradOutput = nil
+      self._weight = nil
+      self._div = nil
+      self._sum = nil
+      self._expand = nil
+      self._expand2 = nil
+      self._expand3 = nil
+      self._repeat = nil
+      self._repeat2 = nil
    end
+   return parent.type(self, type)
 end
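Note: the heart of this rewrite is replacing the per-column loop (input:dist(self.weight:select(2,o)) for each o) with one broadcasted subtraction followed by a single norm call over a dimension. A standalone sketch of that technique, separate from the commit, with illustrative sizes:

   require 'torch'

   local inputSize, outputSize = 4, 3
   local weight = torch.randn(inputSize, outputSize)  -- each column is a center w_j
   local input  = torch.randn(inputSize)              -- one sample x

   -- vectorized: y_j = || x - w_j || for all j at once
   local expanded = input:view(inputSize, 1):expandAs(weight) -- stride trick, no copy
   local diff = expanded:clone():add(-1, weight)              -- materialize x - w_j
   local y = diff:norm(2, 1):resize(outputSize)               -- 2-norm down each column

   -- loop-based reference, as in the code this commit replaces
   local yRef = torch.Tensor(outputSize)
   for o = 1, outputSize do
      yRef[o] = input:dist(weight:select(2, o))
   end
   print((y - yRef):abs():max())  -- ~0, up to floating point error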

test.lua (+47 −2)
@@ -480,9 +480,54 @@ end
 function nntest.Euclidean()
    local ini = math.random(5,7)
    local inj = math.random(5,7)
-   local input = torch.Tensor(ini):zero()
+   local input = torch.randn(ini)
+   local gradOutput = torch.randn(inj)
    local module = nn.Euclidean(ini,inj)
-
+   local output = module:forward(input):clone()
+
+   local output2 = torch.Tensor(inj):zero()
+   for o = 1,module.weight:size(2) do
+      output2[o] = input:dist(module.weight:select(2,o))
+   end
+   mytester:assertTensorEq(output, output2, 0.000001, 'Euclidean forward 1D err')
+
+   local input2 = torch.randn(8, ini)
+   input2[2]:copy(input)
+   local output2 = module:forward(input2)
+   mytester:assertTensorEq(output2[2], output, 0.000001, 'Euclidean forward 2D err')
+
+   local output = module:forward(input):clone()
+   module:zeroGradParameters()
+   local gradInput = module:backward(input, gradOutput, 1):clone()
+   local gradInput2 = torch.zeros(ini)
+   local temp = input:clone()
+   for o = 1,module.weight:size(2) do
+      temp:copy(input)
+      temp:add(-1,module.weight:select(2,o))
+      temp:mul(gradOutput[o]/output[o])
+      gradInput2:add(temp)
+   end
+   mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'Euclidean updateGradInput 1D err')
+
+   local gradWeight = module.gradWeight:clone():zero()
+   for o = 1,module.weight:size(2) do
+      temp:copy(module.weight:select(2,o)):add(-1,input)
+      temp:mul(gradOutput[o]/output[o])
+      gradWeight:select(2,o):add(1, temp)
+   end
+   mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 1D err')
+
+   local input2 = input:view(1, -1):repeatTensor(8, 1)
+   local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
+   local output2 = module:forward(input2)
+   module:zeroGradParameters()
+   local gradInput2 = module:backward(input2, gradOutput2, 1/8)
+   mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'Euclidean updateGradInput 2D err')
+
+   mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 2D err')
+
+   input:zero()
+   module.fastBackward = false
 
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
