Commit

Convert upsampling Functions to new style (pytorch#2372)
lantiga authored and soumith committed Aug 12, 2017
1 parent 641e582 commit cd5275e
Showing 2 changed files with 114 additions and 103 deletions.
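The change follows the new-style autograd Function pattern: forward and backward become @staticmethods that receive a context object (ctx) instead of keeping state on self, backward is wrapped in @once_differentiable, and callers go through .apply rather than instantiating the Function. A minimal sketch of that pattern, using a hypothetical Scale2x op (not part of this commit) so the example stays self-contained:

import torch
from torch.autograd import Variable
from torch.autograd.function import Function, once_differentiable


class Scale2x(Function):
    """Hypothetical stand-in for an upsampling op: multiplies the input by 2."""

    @staticmethod
    def forward(ctx, input, size=None, scale_factor=None):
        # Non-tensor arguments are stored on ctx instead of self.
        ctx.size = size
        ctx.scale_factor = scale_factor
        ctx.save_for_backward(input)   # mirrors ctx.save_for_backward in the nearest ops below
        return input * 2

    @staticmethod
    @once_differentiable               # this backward cannot be differentiated again
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors     # retrieved the same way the nearest ops do
        # One gradient per forward() argument; None for the non-tensor size/scale_factor.
        return grad_output * 2, None, None


x = Variable(torch.randn(2, 3), requires_grad=True)
y = Scale2x.apply(x, None, 2)          # new style: no instantiation, call .apply
y.sum().backward()
print(x.grad)                          # d(2x)/dx == 2 everywhere

Because forward takes the non-tensor size and scale_factor arguments positionally, backward has to return a matching gradient slot for each of them, hence the trailing None, None throughout the converted classes below.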
209 changes: 110 additions & 99 deletions torch/nn/_functions/thnn/upsampling.py
@@ -1,72 +1,72 @@
 from numbers import Integral
 import torch
-from torch.autograd import Function
+from torch.autograd.function import Function, once_differentiable
 from torch._thnn import type2backend
 
 from . import _all_functions
+from ...modules.utils import _pair, _triple
 
 
-class _UpsamplingBase(Function):
-
-    def __init__(self, size=None, scale_factor=None):
-        super(_UpsamplingBase, self).__init__()
-        if size is None and scale_factor is None:
-            raise ValueError('either size or scale_factor should be defined')
-        if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
-            raise ValueError('scale_factor must be of integer type or a tuple of integer types')
-        self.size = size
-        self.scale_factor = scale_factor
+def _check_size_scale_factor(size, scale_factor):
+    if size is None and scale_factor is None:
+        raise ValueError('either size or scale_factor should be defined')
+    if scale_factor is not None and not isinstance(scale_factor, (Integral, tuple)):
+        raise ValueError('scale_factor must be of integer type or a tuple of integer types')
 
 
-class UpsamplingNearest2d(_UpsamplingBase):
+class UpsamplingNearest2d(Function):
 
-    def __init__(self, size=None, scale_factor=None):
-        super(UpsamplingNearest2d, self).__init__(size, scale_factor)
-
-        if self.scale_factor is not None and not isinstance(scale_factor, Integral):
-            raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
-
-    def forward(self, input):
-        assert input.dim() == 4
-
-        if self.scale_factor is None:
-            if (self.size[0] % input.size(2) != 0 or
-                    self.size[1] % input.size(3) != 0):
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
+        assert input.dim() == 4
+
+        _check_size_scale_factor(size, scale_factor)
+
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None and not isinstance(ctx.scale_factor, Integral):
+            raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
+
+        if ctx.scale_factor is None:
+            if (ctx.size[0] % input.size(2) != 0 or
+                    ctx.size[1] % input.size(3) != 0):
                 raise RuntimeError("output size specified in UpsamplingNearest "
                                    "({}) has to be divisible by the input size, but got: "
-                                   "{}".format('x'.join(map(str, self.size)),
+                                   "{}".format('x'.join(map(str, ctx.size)),
                                                'x'.join(map(str, input.size()))))
-            self.scale_factor = self.size[0] // input.size(2)
-            if self.scale_factor != self.size[1] // input.size(3):
+            ctx.scale_factor = ctx.size[0] // input.size(2)
+            if ctx.scale_factor != ctx.size[1] // input.size(3):
                 raise RuntimeError("input aspect ratio doesn't match the "
                                    "output ratio")
 
         output = input.new()
         backend = type2backend[type(input)]
-        self.save_for_backward(input)
+        ctx.save_for_backward(input)
         backend.SpatialUpSamplingNearest_updateOutput(
             backend.library_state,
             input,
             output,
-            self.scale_factor
+            ctx.scale_factor
         )
         return output
 
-    def backward(self, grad_output):
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
         assert grad_output.dim() == 4
 
-        input, = self.saved_tensors
+        input, = ctx.saved_tensors
         grad_input = grad_output.new()
         backend = type2backend[type(input)]
         backend.SpatialUpSamplingNearest_updateGradInput(
             backend.library_state,
             input,
             grad_output,
             grad_input,
-            self.scale_factor
+            ctx.scale_factor
         )
-        return grad_input
+        return grad_input, None, None
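One behavioural detail worth calling out from the forward above: when an explicit size is passed to the nearest-neighbour op, each spatial dimension of size must be an integer multiple of the corresponding input dimension, otherwise the RuntimeError above fires. A small usage sketch, assuming the 0.2-era Variable API and that torch.nn.functional.upsample dispatches here as shown in the second file of this diff:

import torch
from torch.autograd import Variable
import torch.nn.functional as F

x = Variable(torch.randn(1, 1, 4, 4))

y = F.upsample(x, size=(8, 8), mode='nearest')   # 8 is divisible by 4 -> scale_factor becomes 2
print(y.size())                                  # (1, 1, 8, 8)

# F.upsample(x, size=(7, 7), mode='nearest')     # would raise RuntimeError: 7 is not divisible by 4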


 def _check_linear_scale_factor(scale_factor, dim=2):
@@ -87,38 +87,41 @@ def _check_linear_scale_factor(scale_factor, dim=2):
     return scale_factor
 
 
-class UpsamplingBilinear2d(_UpsamplingBase):
+class UpsamplingBilinear2d(Function):
 
-    def __init__(self, size=None, scale_factor=None):
-        super(UpsamplingBilinear2d, self).__init__(size, scale_factor)
-
-        if self.scale_factor is not None:
-            self.scale_factor = _check_linear_scale_factor(self.scale_factor, dim=2)
-
-    def forward(self, input):
-        assert input.dim() == 4
-
-        if self.scale_factor is not None:
-            self.output_size = (
-                input.size(2) * self.scale_factor[0],
-                input.size(3) * self.scale_factor[1],
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
+        assert input.dim() == 4
+
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None:
+            ctx.scale_factor = _check_linear_scale_factor(ctx.scale_factor, dim=2)
+
+        if ctx.scale_factor is not None:
+            ctx.output_size = (
+                input.size(2) * ctx.scale_factor[0],
+                input.size(3) * ctx.scale_factor[1],
             )
         else:
-            self.output_size = self.size
+            ctx.output_size = ctx.size
 
-        self.input_size = input.size()
+        ctx.input_size = input.size()
         output = input.new()
         backend = type2backend[type(input)]
         backend.SpatialUpSamplingBilinear_updateOutput(
             backend.library_state,
             input,
             output,
-            self.output_size[0],
-            self.output_size[1],
+            ctx.output_size[0],
+            ctx.output_size[1],
         )
         return output
 
-    def backward(self, grad_output):
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
         assert grad_output.dim() == 4
 
         grad_output = grad_output.contiguous()
@@ -128,94 +131,102 @@ def backward(self, grad_output):
             backend.library_state,
             grad_output,
             grad_input,
-            self.input_size[0],
-            self.input_size[1],
-            self.input_size[2],
-            self.input_size[3],
-            self.output_size[0],
-            self.output_size[1],
+            ctx.input_size[0],
+            ctx.input_size[1],
+            ctx.input_size[2],
+            ctx.input_size[3],
+            ctx.output_size[0],
+            ctx.output_size[1],
         )
-        return grad_input
+        return grad_input, None, None
 
 
-class UpsamplingNearest3d(_UpsamplingBase):
-    def __init__(self, size=None, scale_factor=None):
-        super(UpsamplingNearest3d, self).__init__(size, scale_factor)
-
-        if self.scale_factor is not None and not isinstance(scale_factor, Integral):
-            raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
-
-    def forward(self, input):
+class UpsamplingNearest3d(Function):
+
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
         assert input.dim() == 5
 
-        if self.scale_factor is None:
-            if (self.size[0] % input.size(2) != 0 or self.size[1] % input.size(3) != 0 or
-                    self.size[2] % input.size(4) != 0):
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None and not isinstance(ctx.scale_factor, Integral):
+            raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
+
+        if ctx.scale_factor is None:
+            if (ctx.size[0] % input.size(2) != 0 or ctx.size[1] % input.size(3) != 0 or
+                    ctx.size[2] % input.size(4) != 0):
                 raise RuntimeError("output size specified in UpSamplingNearest "
                                    "({}) has to be divisible by the input size, but got: "
-                                   "{}".format('x'.join(map(str, self.size)),
+                                   "{}".format('x'.join(map(str, ctx.size)),
                                                'x'.join(map(str, input.size()))))
-            self.scale_factor = self.size[0] // input.size(2)
-            if (self.scale_factor != self.size[1] // input.size(3) or
-                    self.scale_factor != self.size[2] // input.size(4)):
+            ctx.scale_factor = ctx.size[0] // input.size(2)
+            if (ctx.scale_factor != ctx.size[1] // input.size(3) or
+                    ctx.scale_factor != ctx.size[2] // input.size(4)):
                 raise RuntimeError("input aspect ratio doesn't match the "
                                    "output ratio")
 
         output = input.new()
         backend = type2backend[type(input)]
-        self.save_for_backward(input)
+        ctx.save_for_backward(input)
         backend.VolumetricUpSamplingNearest_updateOutput(backend.library_state,
                                                          input,
                                                          output,
-                                                         self.scale_factor)
+                                                         ctx.scale_factor)
         return output
 
-    def backward(self, grad_output):
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
         assert grad_output.dim() == 5
-        input, = self.saved_tensors
+        input, = ctx.saved_tensors
         grad_input = grad_output.new()
         backend = type2backend[type(input)]
         backend.VolumetricUpSamplingNearest_updateGradInput(backend.library_state,
                                                             input,
                                                             grad_output,
                                                             grad_input,
-                                                            self.scale_factor)
-        return grad_input
+                                                            ctx.scale_factor)
+        return grad_input, None, None
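Because the converted ops are now plain static Functions, they can be exercised directly through .apply, for example with a numerical gradient check. A sketch under stated assumptions: the internal torch.nn._functions.thnn import path used by functional.py below stays importable, and gradcheck is available from torch.autograd.gradcheck; gradcheck wants double precision and only differentiates tensor arguments, so the non-tensor size/scale_factor are bound in a lambda:

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck
from torch.nn._functions.thnn import UpsamplingNearest2d

x = Variable(torch.DoubleTensor(1, 2, 3, 3).normal_(), requires_grad=True)

# forward(ctx, input, size, scale_factor): bind the non-tensor arguments, gradcheck varies only x
ok = gradcheck(lambda t: UpsamplingNearest2d.apply(t, None, 2), (x,), eps=1e-6, atol=1e-4)
print(ok)  # True if the analytical gradient matches the numerical one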


-class UpsamplingTrilinear3d(_UpsamplingBase):
-    def __init__(self, size=None, scale_factor=None):
-        super(UpsamplingTrilinear3d, self).__init__(size, scale_factor)
-
-        if self.scale_factor is not None:
-            self.scale_factor = _check_linear_scale_factor(self.scale_factor, dim=3)
-
-    def forward(self, input):
+class UpsamplingTrilinear3d(Function):
+
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
         assert input.dim() == 5
 
-        if self.scale_factor is not None:
-            self.output_size = (
-                input.size(2) * self.scale_factor[0],
-                input.size(3) * self.scale_factor[1],
-                input.size(4) * self.scale_factor[2],
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None:
+            ctx.scale_factor = _check_linear_scale_factor(ctx.scale_factor, dim=3)
+
+        if ctx.scale_factor is not None:
+            ctx.output_size = (
+                input.size(2) * ctx.scale_factor[0],
+                input.size(3) * ctx.scale_factor[1],
+                input.size(4) * ctx.scale_factor[2],
             )
         else:
-            self.output_size = self.size
+            ctx.output_size = ctx.size
 
-        self.input_size = input.size()
+        ctx.input_size = input.size()
         output = input.new()
         backend = type2backend[type(input)]
         backend.VolumetricUpSamplingTrilinear_updateOutput(
             backend.library_state,
             input,
             output,
-            self.output_size[0],
-            self.output_size[1],
-            self.output_size[2]
+            ctx.output_size[0],
+            ctx.output_size[1],
+            ctx.output_size[2]
         )
         return output
 
-    def backward(self, grad_output):
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
         assert grad_output.dim() == 5
 
         grad_output = grad_output.contiguous()
Expand All @@ -225,16 +236,16 @@ def backward(self, grad_output):
backend.library_state,
grad_output,
grad_input,
self.input_size[0],
self.input_size[1],
self.input_size[2],
self.input_size[3],
self.input_size[4],
self.output_size[0],
self.output_size[1],
self.output_size[2]
ctx.input_size[0],
ctx.input_size[1],
ctx.input_size[2],
ctx.input_size[3],
ctx.input_size[4],
ctx.output_size[0],
ctx.output_size[1],
ctx.output_size[2]
)
return grad_input
return grad_input, None, None


_all_functions.append(UpsamplingNearest2d)
8 changes: 4 additions & 4 deletions torch/nn/functional.py
@@ -913,17 +913,17 @@ def upsample(input, size=None, scale_factor=None, mode='nearest'):
         'nearest' | 'bilinear' | 'trilinear'. Default: 'nearest'
     """
     if input.dim() == 4 and mode == 'nearest':
-        return _functions.thnn.UpsamplingNearest2d(_pair(size), scale_factor)(input)
+        return _functions.thnn.UpsamplingNearest2d.apply(input, _pair(size), scale_factor)
     elif input.dim() == 5 and mode == 'nearest':
-        return _functions.thnn.UpsamplingNearest3d(_triple(size), scale_factor)(input)
+        return _functions.thnn.UpsamplingNearest3d.apply(input, _triple(size), scale_factor)
     elif input.dim() == 4 and mode == 'bilinear':
-        return _functions.thnn.UpsamplingBilinear2d(_pair(size), scale_factor)(input)
+        return _functions.thnn.UpsamplingBilinear2d.apply(input, _pair(size), scale_factor)
     elif input.dim() == 4 and mode == 'trilinear':
         raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
     elif input.dim() == 5 and mode == 'bilinear':
         raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
     elif input.dim() == 5 and mode == 'trilinear':
-        return _functions.thnn.UpsamplingTrilinear3d(_triple(size), scale_factor)(input)
+        return _functions.thnn.UpsamplingTrilinear3d.apply(input, _triple(size), scale_factor)
     else:
         raise NotImplementedError("Input Error: Only 4D and 5D input Tensors supported"
                                   " (got {}D) for the modes: nearest | bilinear | trilinear"
