Skip to content

Commit

Permalink
ENH Adds no_batch_dim to FractionalMaxPool2d (pytorch#62490)
Browse files Browse the repository at this point in the history
Summary:
Towards pytorch#60585

Pull Request resolved: pytorch#62490

Reviewed By: bdhirsh

Differential Revision: D30287143

Pulled By: jbschlosser

fbshipit-source-id: 1b9dd932157f571adf3aa2c98c3c6b56ece8fa6e
  • Loading branch information
thomasjpfan authored and facebook-github-bot committed Aug 13, 2021
1 parent 61b49c8 commit c5f3ab6
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FractionalMaxPool2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ void fractional_max_pool2d_backward_out_cuda_template(
gradInput_.size(0));
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

auto devIndices = indices.packed_accessor<int64_t, 4>();
auto devIndices = indices_.packed_accessor<int64_t, 4>();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(),
"fractional_max_pool2d_backward_out_cuda_frame",
[&] {
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/api/include/torch/nn/functional/pooling.h
Original file line number Diff line number Diff line change
Expand Up @@ -767,17 +767,17 @@ inline std::tuple<Tensor, Tensor> fractional_max_pool2d_with_indices(
"fractional_max_pool2d requires specifying either ",
"an output_size or an output_ratio");
}

c10::optional<ExpandingArray<2>> output_size_ = output_size;
if (output_size_ == c10::nullopt) {
TORCH_INTERNAL_ASSERT(output_ratio != c10::nullopt);
output_size_ = {(int64_t)(input.sizes()[2] * (*output_ratio.value())[0]),
(int64_t)(input.sizes()[3] * (*output_ratio.value())[1])};
output_size_ = {(int64_t)(input.size(-2) * (*output_ratio.value())[0]),
(int64_t)(input.size(-1) * (*output_ratio.value())[1])};
}

Tensor _random_samples_ = _random_samples;
if (!_random_samples_.defined()) {
_random_samples_ = torch::rand({input.sizes()[0], input.sizes()[1], 2}, torch::TensorOptions().dtype(input.dtype()).device(input.device()));
auto n_batch = 1 ? input.dim() == 3 : input.size(0);
_random_samples_ = torch::rand({n_batch, input.size(-1), 2}, torch::TensorOptions().dtype(input.dtype()).device(input.device()));
}
return torch::fractional_max_pool2d(input, kernel_size, *output_size_, _random_samples_);
}
Expand Down
5 changes: 3 additions & 2 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,11 @@ def fractional_max_pool2d_with_indices(
if output_size is None:
assert output_ratio is not None
_output_ratio = _pair(output_ratio)
output_size = [int(input.size(2) * _output_ratio[0]), int(input.size(3) * _output_ratio[1])]
output_size = [int(input.size(-2) * _output_ratio[0]), int(input.size(-1) * _output_ratio[1])]

if _random_samples is None:
_random_samples = torch.rand(input.size(0), input.size(1), 2, dtype=input.dtype, device=input.device)
n_batch = 1 if input.dim() == 3 else input.size(0)
_random_samples = torch.rand(n_batch, input.size(-3), 2, dtype=input.dtype, device=input.device)
return torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)


Expand Down
6 changes: 6 additions & 0 deletions torch/nn/modules/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,12 @@ class FractionalMaxPool2d(Module):
return_indices: if ``True``, will return the indices along with the outputs.
Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``
Shape:
- Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
- Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
:math:`(H_{out}, W_{out})=\text{output\_size}` or
:math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`.
Examples:
>>> # pool of square window of size=3, and target output size 13x12
>>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
Expand Down
54 changes: 54 additions & 0 deletions torch/testing/_internal/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,56 @@ def fractional_max_pool2d_test(test_case):
fullname='FractionalMaxPool2d_size')


def fractional_max_pool2d_no_batch_dim_test(test_case, use_random_samples):
if use_random_samples:
# random_samples enables CPU and GPU checks to be consistent
random_samples = torch.empty((1, 3, 2), dtype=torch.double).uniform_()
if test_case == 'ratio':
return dict(
constructor=lambda: nn.FractionalMaxPool2d(
2, output_ratio=0.5, _random_samples=random_samples),
cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions(2)
.output_ratio(0.5)
._random_samples(random_samples)''',
input_size=(3, 5, 7),
cpp_var_map={'random_samples': random_samples},
reference_fn=single_batch_reference_fn,
fullname='FractionalMaxPool2d_ratio_no_batch_dim')
elif test_case == 'size':
return dict(
constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
4, 3), _random_samples=random_samples),
cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions({2, 3})
.output_size(std::vector<int64_t>({4, 3}))
._random_samples(random_samples)''',
input_size=(3, 7, 6),
cpp_var_map={'random_samples': random_samples},
reference_fn=single_batch_reference_fn,
fullname='FractionalMaxPool2d_size_no_batch_dim')
else:
# can not check cuda because there RNG is different between cpu and cuda
if test_case == 'ratio':
return dict(
constructor=lambda: nn.FractionalMaxPool2d(
2, output_ratio=0.5),
cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions(2)
.output_ratio(0.5)''',
input_size=(3, 5, 7),
reference_fn=single_batch_reference_fn,
test_cuda=False,
fullname='FractionalMaxPool2d_ratio_no_batch_dim_no_random_samples')
elif test_case == 'size':
return dict(
constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
4, 3)),
cpp_constructor_args='''torch::nn::FractionalMaxPool2dOptions({2, 3})
.output_size(std::vector<int64_t>({4, 3}))''',
input_size=(3, 7, 6),
reference_fn=single_batch_reference_fn,
test_cuda=False,
fullname='FractionalMaxPool2d_size_no_batch_dim_no_random_samples')


def fractional_max_pool3d_test(test_case):
random_samples = torch.empty((2, 4, 3), dtype=torch.double).uniform_()
if test_case == 'ratio':
Expand Down Expand Up @@ -1314,6 +1364,10 @@ def single_batch_reference_fn(input, parameters, module):
fractional_max_pool3d_test('ratio'),
fractional_max_pool3d_test('size'),
fractional_max_pool3d_test('asymsize'),
fractional_max_pool2d_no_batch_dim_test('ratio', True),
fractional_max_pool2d_no_batch_dim_test('ratio', False),
fractional_max_pool2d_no_batch_dim_test('size', True),
fractional_max_pool2d_no_batch_dim_test('size', False),
dict(
module_name='BatchNorm1d',
constructor_args=(10,),
Expand Down

0 comments on commit c5f3ab6

Please sign in to comment.