From 08195f320495a50b480628901fd96071fca8d18f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 24 Oct 2015 22:24:52 +0200 Subject: [PATCH] Fix possible division by zero in Normalize When p < 2, if the input is 0 there was division by zero. Add a small dampening factor to avoid this. Also, small optimizations for different p. --- Normalize.lua | 19 ++++++++++++++++++- test.lua | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/Normalize.lua b/Normalize.lua index 304caa047..58dc11d9d 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -53,11 +53,28 @@ function Normalize:updateGradInput(input, gradOutput) gradOutput = gradOutput:view(n,d,1) self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput) + -- small optimizations for different p + -- buffer = input*|input|^(p-2) + if self.p % 2 ~= 0 then + -- for non-even p, need to add absolute value + if self.p < 2 then + -- add eps to avoid possible division by 0 + self.buffer:abs(input):add(self.eps):pow(self.p-2):cmul(input) + else + self.buffer:abs(input):pow(self.p-2):cmul(input) + end + elseif self.p == 2 then + -- special case for p == 2, pow(x,0) = 1 + self.buffer:copy(input) + else + -- p is even and > 2, so pow(x,p-2) is always non-negative and abs is not needed + self.buffer:pow(input,self.p-2):cmul(input) + end + -- compute cross term in two steps self.cross = self.cross or input.new() self.cross:resize(n,1,1) - self.buffer:abs(input):pow(self.p-2):cmul(input) local b1 = self.buffer:view(n,d,1) local b2 = input:view(n,1,d) -- instead of having a huge temporary matrix (b1*b2), diff --git a/test.lua b/test.lua index 03fef0b48..9c02743d3 100644 --- a/test.lua +++ b/test.lua @@ -438,7 +438,7 @@ end function nntest.Normalize() -- compare forward against torch implementation -- and check gradient - for _,p in pairs({1,2,1.5}) do + for _,p in pairs({1,2,3,4,1.5}) do local ini = math.random(3,10) local input = torch.randn(ini) local module = nn.Normalize(p) @@ -452,7 +452,7 @@ function nntest.Normalize() end 
-- batch mode - for _,p in pairs({1,2,torch.uniform()*math.random(1,10)}) do + for _,p in pairs({1,2,3,4,torch.uniform()*math.random(1,10)}) do local ini = math.random(3,5) local inj = math.random(3,5) local ink = math.random(3,5)