Skip to content

Commit

Permalink
ENH Adds no batch dim support for AvgPool1d (pytorch#61860)
Browse files Browse the repository at this point in the history
Summary:
Towards pytorch#60585

Pull Request resolved: pytorch#61860

Reviewed By: albanD

Differential Revision: D29826382

Pulled By: jbschlosser

fbshipit-source-id: 47e12073d866f0604310fc1ff270cde9907e516d
  • Loading branch information
thomasjpfan authored and facebook-github-bot committed Jul 22, 2021
1 parent 5a00152 commit 0309c57
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/native/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,20 @@ Tensor avg_pool1d(
if (stride.empty()) {
stride = kernel_size;
}
checkDim("avg_pool1d", TensorArg(self, "self", 1), 3);
checkDimRange("avg_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
check1d("avg_pool1d", "kernel_size", kernel_size);
check1d("avg_pool1d", "stride", stride);
check1d("avg_pool1d", "padding", padding);

auto output = at::avg_pool2d(
self.unsqueeze(2),
self.unsqueeze(-2),
{1, kernel_size[0]},
{1, stride[0]},
{0, padding[0]},
ceil_mode,
count_include_pad);

return output.squeeze(2);
return output.squeeze(-2);
}

Tensor max_pool2d(
Expand Down
4 changes: 2 additions & 2 deletions torch/nn/modules/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,8 @@ class AvgPool1d(_AvgPoolNd):
count_include_pad: when True, will include the zero-padding in the averaging calculation
Shape:
- Input: :math:`(N, C, L_{in})`
- Output: :math:`(N, C, L_{out})`, where
- Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in}`.
- Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
.. math::
L_{out} = \left\lfloor \frac{L_{in} +
Expand Down
8 changes: 8 additions & 0 deletions torch/testing/_internal/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,14 @@ def single_batch_reference_fn(input, parameters, module):
input_size=(2, 3, 6),
desc='stride_pad',
),
dict(
module_name='AvgPool1d',
constructor_args=(2,),
cpp_constructor_args='torch::nn::AvgPool1dOptions(2)',
input_size=(3, 6),
reference_fn=single_batch_reference_fn,
desc='no_batch_dim',
),
dict(
module_name='AvgPool2d',
constructor_args=((2, 2),),
Expand Down

0 comments on commit 0309c57

Please sign in to comment.