Skip to content

Commit

Permalink
SpatialAveragePooling divides by kW*kH
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko committed Mar 4, 2015
1 parent 7470cc8 commit cb30332
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
16 changes: 14 additions & 2 deletions SpatialAveragePooling.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,26 @@ function SpatialAveragePooling:__init(kW, kH, dW, dH)
self.kH = kH
self.dW = dW or 1
self.dH = dH or 1
self.divide = true
end

function SpatialAveragePooling:updateOutput(input)
return input.nn.SpatialAveragePooling_updateOutput(self, input)
input.nn.SpatialAveragePooling_updateOutput(self, input)
-- for backward compatibility with saved models
-- which are not supposed to have "divide" field
if not self.divide then
self.output:mul(self.kW*self.kH)
end
return self.output
end

function SpatialAveragePooling:updateGradInput(input, gradOutput)
if self.gradInput then
return input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
-- for backward compatibility
if not self.divide then
self.gradInput:mul(self.kW*self.kH)
end
return self.gradInput
end
end
1 change: 1 addition & 0 deletions SpatialLPPooling.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ function SpatialLPPooling:__init(nInputPlane, pnorm, kW, kH, dW, dH)
self:add(nn.Power(pnorm))
end
self:add(nn.SpatialAveragePooling(kW, kH, dW, dH))
self:add(nn.MulConstant(kW*kH))
if pnorm == 2 then
self:add(nn.Sqrt())
else
Expand Down
4 changes: 2 additions & 2 deletions generic/SpatialAveragePooling.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L)
ptr_input += inputWidth; /* next input line */
}
/* Update output */
*ptr_output++ += sum;
*ptr_output++ += sum/(kW*kH);
}
}
}
Expand Down Expand Up @@ -163,7 +163,7 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L)
for(ky = 0; ky < kH; ky++)
{
for(kx = 0; kx < kW; kx++)
ptr_gradInput[kx] += z;
ptr_gradInput[kx] += z/(kW*kH);
ptr_gradInput += inputWidth;
}
}
Expand Down
4 changes: 2 additions & 2 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,7 @@ function nntest.SpatialAveragePooling()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')

local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
sap.weight:fill(1.0)
sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)

local output = module:forward(input)
Expand Down Expand Up @@ -1737,7 +1737,7 @@ function nntest.SpatialAveragePooling()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')

local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
sap.weight:fill(1.0)
sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)

local output = module:forward(input)
Expand Down

0 comments on commit cb30332

Please sign in to comment.