-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialUpSamplingNearest.lua
58 lines (50 loc) · 1.93 KB
/
SpatialUpSamplingNearest.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
-- Register the class 'nn.SpatialUpSamplingNearest' with 'nn.Module' named as
-- its parent class. `parent` is used by the constructor below to run the
-- base-class initialisation.
local SpatialUpSamplingNearest, parent = torch.class('nn.SpatialUpSamplingNearest', 'nn.Module')
--[[
Applies a 2D up-sampling over an input image composed of several input planes.
The up-sampling uses the simple nearest-neighbour technique.

The input must be a 3D or 4D tensor. The Y and X dimensions are assumed to be
the last two tensor dimensions; for instance, if the tensor is 4D, then dim 3
is the Y dimension and dim 4 is the X dimension.

With an integer scale_factor >= 1, the output spatial sizes are:
   owidth  = width  * scale_factor
   oheight = height * scale_factor
All other dimensions keep their input size.
--]]
--- Construct a nearest-neighbour up-sampling module.
-- @param scale  integer up-sampling factor applied to both spatial
--               dimensions; must be >= 1 (a factor of 1 keeps the size).
-- Raises an error when `scale` is not a number, is smaller than 1, or is not
-- a finite integer value.
function SpatialUpSamplingNearest:__init(scale)
   parent.__init(self)

   -- Reject nil / non-number arguments explicitly; otherwise the comparison
   -- below would fail with an unrelated "attempt to compare" runtime error.
   if type(scale) ~= 'number' then
      error('scale_factor must be a number')
   end
   -- A factor of exactly 1 is accepted, so the message states ">= 1".
   if scale < 1 then
      error('scale_factor must be greater than or equal to 1')
   end
   -- math.floor(math.huge) == math.huge, so infinity has to be excluded
   -- separately; NaN is already caught because NaN ~= NaN.
   if scale == math.huge or math.floor(scale) ~= scale then
      error('scale_factor must be integer')
   end
   self.scale_factor = scale

   -- Scratch storages holding the input and output sizes; they have four
   -- slots so that both 3D and 4D inputs fit.
   self.inputSize = torch.LongStorage(4)
   self.outputSize = torch.LongStorage(4)
   -- Kept from the original code: clears any `usage` field on the instance.
   -- NOTE(review): no reader of `usage` is visible in this file.
   self.usage = nil
end
--- Forward pass: up-sample `input` along its last two dimensions.
-- @param input  3D or 4D tensor; the last dimension is treated as X and the
--               one before it as Y.
-- @return self.output, resized so that its last two dimensions are the input
--         sizes multiplied by self.scale_factor.
-- Raises an error when `input` is neither 3D nor 4D.
function SpatialUpSamplingNearest:updateOutput(input)
   local nDim = input:dim()
   if not (nDim == 3 or nDim == 4) then
      error('SpatialUpSamplingNearest only support 3D or 4D tensors')
   end

   -- Record the input size and start the output size as a copy of it.
   local inSize, outSize = self.inputSize, self.outputSize
   for d = 1, nDim do
      local extent = input:size(d)
      inSize[d] = extent
      outSize[d] = extent
   end

   -- Only the two trailing (Y, X) dimensions are scaled.
   local factor = self.scale_factor
   for d = nDim - 1, nDim do
      outSize[d] = outSize[d] * factor
   end

   -- Resize the output if needed. The size storage always has four slots and
   -- the loop above leaves the fourth one unwritten for 3D input, so the three
   -- sizes are passed explicitly in that case.
   if nDim == 4 then
      self.output:resize(outSize)
   else
      self.output:resize(outSize[1], outSize[2], outSize[3])
   end

   -- Delegate the actual computation to the backend routine registered for
   -- this tensor type (presumably it fills self.output -- not visible here).
   input.nn.SpatialUpSamplingNearest_updateOutput(self, input)
   return self.output
end
--- Backward pass: compute the gradient with respect to `input`.
-- @param input       the tensor that was given to updateOutput
-- @param gradOutput  gradient with respect to the module's output
-- @return self.gradInput, resized to the shape of `input`
function SpatialUpSamplingNearest:updateGradInput(input, gradOutput)
   -- The gradient buffer must have the same shape as the input.
   local gradBuffer = self.gradInput
   gradBuffer:resizeAs(input)

   -- Delegate to the backend routine registered for this tensor type
   -- (presumably it writes into self.gradInput -- not visible here).
   local backend = input.nn
   backend.SpatialUpSamplingNearest_updateGradInput(self, input, gradOutput)
   return self.gradInput
end