Ignore updateGradInput if self.gradInput is nil
Change suggested by Natalia Gimelshein

To align behaviour with that of other modules (Linear, SpatialConvolution, etc.).
This allows model definitions such as:

local lenet = nn.Sequential()
lenet:add(nn.MulConstant(0.00390625))
lenet:add(nn.SpatialConvolution(1,20,5,5,1,1,0)) -- 1*28*28 -> 20*24*24
lenet:add(nn.SpatialMaxPooling(2, 2, 2, 2)) -- 20*24*24 -> 20*12*12
lenet:add(nn.SpatialConvolution(20,50,5,5,1,1,0)) -- 20*12*12 -> 50*8*8
lenet:add(nn.SpatialMaxPooling(2,2,2,2)) --  50*8*8 -> 50*4*4
lenet:add(nn.View(-1):setNumInputDims(3))  -- 50*4*4 -> 800
lenet:add(nn.Linear(800,500))  -- 800 -> 500
lenet:add(nn.ReLU())
lenet:add(nn.Linear(500, 10))  -- 500 -> 10
lenet:add(nn.LogSoftMax())
lenet:get(1).gradInput = nil
lenet:get(2).gradInput = nil

Setting gradInput to nil on the first two layers skips their unnecessary
input-gradient (dgrad) computations and saves about 5% of compute utilization.
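The guard this commit adds can be sketched in plain Lua. The snippet below is a hypothetical stand-in with no Torch dependency (tables instead of tensors, a made-up constructor), only meant to show why returning nothing when self.gradInput is nil lets callers skip dgrad work:

```lua
-- Minimal sketch of the nil-gradInput guard; MulConstant here is a
-- hypothetical pure-Lua stand-in for the nn module, not the real one.
local MulConstant = {}
MulConstant.__index = MulConstant

function MulConstant.new(c)
  -- gradInput defaults to a (empty) table, mirroring how nn modules
  -- default it to a tensor; users may set it to nil to opt out.
  return setmetatable({constant = c, gradInput = {}}, MulConstant)
end

function MulConstant:updateGradInput(input, gradOutput)
  if self.gradInput then  -- skip all work when the user set it to nil
    self.gradInput = {}
    for i, g in ipairs(gradOutput) do
      self.gradInput[i] = g * self.constant
    end
    return self.gradInput
  end
  -- falls through and returns nil: no gradient is propagated upstream
end

local m = MulConstant.new(0.5)
local g = m:updateGradInput({1, 2}, {10, 20})
assert(g[1] == 5 and g[2] == 10)   -- normal path: scaled gradient
m.gradInput = nil
assert(m:updateGradInput({1, 2}, {10, 20}) == nil)  -- guarded path
```

The same pattern is what the diff below applies to the real module: the whole body is wrapped in `if self.gradInput then ... end`, so a nil gradInput turns the call into a no-op.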
gheinrich committed Oct 18, 2015
1 parent 42ef6c4 commit f290f3c
Showing 1 changed file: MulConstant.lua, with 14 additions and 12 deletions.
@@ -4,7 +4,7 @@ function MulConstant:__init(constant_scalar,ip)
   parent.__init(self)
   assert(type(constant_scalar) == 'number', 'input is not scalar!')
   self.constant_scalar = constant_scalar
-
+  -- default for inplace is false
   self.inplace = ip or false
   if (ip and type(ip) ~= 'boolean') then
@@ -22,18 +22,20 @@ function MulConstant:updateOutput(input)
     self.output:mul(self.constant_scalar)
   end
   return self.output
 end

 function MulConstant:updateGradInput(input, gradOutput)
-  if self.inplace then
-    gradOutput:mul(self.constant_scalar)
-    self.gradInput = gradOutput
-    -- restore previous input value
-    input:div(self.constant_scalar)
-  else
-    self.gradInput:resizeAs(gradOutput)
-    self.gradInput:copy(gradOutput)
-    self.gradInput:mul(self.constant_scalar)
-  end
-  return self.gradInput
+  if self.gradInput then
+    if self.inplace then
+      gradOutput:mul(self.constant_scalar)
+      self.gradInput = gradOutput
+      -- restore previous input value
+      input:div(self.constant_scalar)
+    else
+      self.gradInput:resizeAs(gradOutput)
+      self.gradInput:copy(gradOutput)
+      self.gradInput:mul(self.constant_scalar)
+    end
+    return self.gradInput
+  end
 end
