Skip to content

Commit

Permalink
Fix possible division by zero in Normalize
Browse files Browse the repository at this point in the history
When p < 2, if the input is 0 there was division by zero. Add a small dampening factor to avoid this. Also, small optimizations for different p.
  • Loading branch information
fmassa committed Oct 24, 2015

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d1d20db commit 08195f3
Showing 2 changed files with 20 additions and 3 deletions.
19 changes: 18 additions & 1 deletion Normalize.lua
Original file line number Diff line number Diff line change
@@ -53,11 +53,28 @@ function Normalize:updateGradInput(input, gradOutput)
gradOutput = gradOutput:view(n,d,1)
self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)

-- small optimizations for different p
-- buffer = input*|input|^(p-2)
if self.p % 2 ~= 0 then
-- for non-even p, need to add absolute value
if self.p < 2 then
-- add eps to avoid possible division by 0
self.buffer:abs(input):add(self.eps):pow(self.p-2):cmul(input)
else
self.buffer:abs(input):pow(self.p-2):cmul(input)
end
elseif self.p == 2 then
-- special case for p == 2, pow(x,0) = 1
self.buffer:copy(input)
else
-- p is even and > 2, pow(x,p) is always positive
self.buffer:pow(input,self.p-2):cmul(input)
end

-- compute cross term in two steps
self.cross = self.cross or input.new()
self.cross:resize(n,1,1)

self.buffer:abs(input):pow(self.p-2):cmul(input)
local b1 = self.buffer:view(n,d,1)
local b2 = input:view(n,1,d)
-- instead of having a huge temporary matrix (b1*b2),
4 changes: 2 additions & 2 deletions test.lua
Original file line number Diff line number Diff line change
@@ -438,7 +438,7 @@ end
function nntest.Normalize()
-- compare forward against torch implementation
-- and check gradient
for _,p in pairs({1,2,1.5}) do
for _,p in pairs({1,2,3,4,1.5}) do
local ini = math.random(3,10)
local input = torch.randn(ini)
local module = nn.Normalize(p)
@@ -452,7 +452,7 @@ function nntest.Normalize()
end

-- batch mode
for _,p in pairs({1,2,torch.uniform()*math.random(1,10)}) do
for _,p in pairs({1,2,3,4,torch.uniform()*math.random(1,10)}) do
local ini = math.random(3,5)
local inj = math.random(3,5)
local ink = math.random(3,5)

0 comments on commit 08195f3

Please sign in to comment.