Change MaxUnpool to accept tensors with 0-dim batch sizes. (pytorch#64082)

Summary:
Part of the fix for pytorch#38115.

Changes the `MaxUnpool` modules to work with zero-sized batch dimensions; an illustrative usage sketch follows the commit metadata below.

Pull Request resolved: pytorch#64082

Reviewed By: mrshenli

Differential Revision: D30793907

Pulled By: jbschlosser

fbshipit-source-id: d21aa665be5aa18f592b39ef7b4e3cbc632e21ed
v0dro authored and facebook-github-bot committed Sep 8, 2021
1 parent ba8c1fc commit 7205ca0
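
An illustrative sketch of the behavior this commit enables (not part of the commit itself; it mirrors the new test_MaxUnpool_zero_batch_dim test added below, assumes a build that includes this change, and paraphrases the error text from the new shape check):

import torch

# Zero-sized batch dimension: previously rejected with "Input must be non-empty",
# now round-trips through pool/unpool and produces empty outputs and gradients.
pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
unpool = torch.nn.MaxUnpool2d(2, stride=2)
x = torch.randn(0, 3, 8, 8, requires_grad=True)
out, indices = pool(x)
y = unpool(out, indices)   # shape: torch.Size([0, 3, 8, 8])
y.sum().backward()         # x.grad is an all-zero (empty) tensor of x's shape

# Zero-sized non-batch dimensions are still rejected by the new per-dimension check.
bad = torch.randn(2, 0, 8, 8)
bad_indices = torch.zeros(2, 0, 8, 8, dtype=torch.int64)
try:
    unpool(bad, bad_indices)
except RuntimeError as e:
    print(e)  # "... Expected input to have non-zero size for non-batch dimensions ..."
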
Showing 3 changed files with 107 additions and 32 deletions.
38 changes: 29 additions & 9 deletions aten/src/ATen/native/MaxUnpooling.cpp
@@ -25,7 +25,11 @@ Tensor& max_unpooling2d_forward_out_cpu(
self_.sizes() == indices_.sizes(),
"Shape of indices should match shape of input");

TORCH_CHECK(self_.numel() > 0, "Input must be non-empty");
for (int64_t i = 1; i < self_.ndimension(); ++i) {
TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cpu(): ",
"Expected input to have non-zero size for non-batch dimensions, but got ",
self_.sizes(), " with dimension ", i , " being empty.");
}

auto memory_format = self_.suggest_memory_format();
auto self = self_.contiguous(memory_format);
@@ -41,7 +45,10 @@ Tensor& max_unpooling2d_forward_out_cpu(
}
output.zero_();

max_unpool2d_kernel(kCPU, output, self, indices);
if (output.numel() != 0) {
max_unpool2d_kernel(kCPU, output, self, indices);
}

return output;
};

@@ -60,7 +67,8 @@ static void max_unpooling3d_shape_check(
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding) {
IntArrayRef padding,
const char *fn_name) {
int64_t oT = output_size[0];
int64_t oH = output_size[1];
int64_t oW = output_size[2];
@@ -84,7 +92,11 @@ static void max_unpooling3d_shape_check(
input.sizes() == indices.sizes(),
"Shape of indices should match shape of input");

TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
for (int64_t i = 1; i < input.ndimension(); ++i) {
TORCH_CHECK(input.size(i) > 0, fn_name,
": Expected input to have non-zero size for non-batch dimensions, but got ",
input.sizes(), " with dimension ", i , " being empty.");
}

TORCH_CHECK(
stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -144,16 +156,18 @@ Tensor& max_unpooling3d_forward_out_cpu(const Tensor& self_,
auto indices = indices_.contiguous();

max_unpooling3d_shape_check(
self_, Tensor(), indices_, output_size, stride, padding);
self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cpu()");

if (self_.ndimension() == 5) {
output.resize_({self.size(0), self.size(1), oT, oH, oW});
} else {
output.resize_({self.size(0), oT, oH, oW});
}
output.zero_();
if (output.numel() != 0) {
max_unpool3d_kernel(kCPU, output, self, indices);
}

max_unpool3d_kernel(kCPU, output, self, indices);
return output;
}

@@ -207,7 +221,10 @@ Tensor& max_unpooling2d_backward_out_cpu(const Tensor& grad_output_,
grad_output.size(dimw));
}

max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
if (grad_input.numel() != 0) {
max_unpool2d_backward_kernel(kCPU, grad_input, grad_output, indices);
}

return grad_input;
}

@@ -240,7 +257,7 @@ Tensor& max_unpooling3d_backward_out_cpu(
int64_t dimw = ndim == 4 ? 3 : 4;

max_unpooling3d_shape_check(
self, grad_output_, indices_, output_size, stride, padding);
self, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cpu()");

/* get contiguous gradOutput */
auto grad_output = grad_output_.contiguous();
@@ -266,7 +283,10 @@ Tensor& max_unpooling3d_backward_out_cpu(
grad_output.size(dimw));
}

max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
if (grad_input.numel() != 0) {
max_unpool3d_backward_kernel(kCPU, grad_input, grad_output, indices);
}

return grad_input;
}

67 changes: 44 additions & 23 deletions aten/src/ATen/native/cuda/MaxUnpooling.cu
@@ -114,7 +114,11 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
checkAllSameGPU(
"max_unpooling2d_forward_out_cuda", {output_arg, self_arg, indices_arg});

TORCH_CHECK(self_.numel() > 0, "Input must be non-empty tensor");
for (int64_t i = 1; i < self_.ndimension(); ++i) {
TORCH_CHECK(self_.size(i) > 0, "max_unpooling2d_forward_out_cuda(): ",
"Expected input to have non-zero size for non-batch dimensions, but got ",
self_.sizes(), " with dimension ", i , " being empty.");
}

TORCH_CHECK(
(self_.ndimension() == 3 || self_.ndimension() == 4),
@@ -152,24 +156,26 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
output.zero_();

auto count = self.numel();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
max_unpooling2d_forward_kernel<<<
GET_BLOCKS(count),
CUDA_NUM_THREADS,
0,
at::cuda::getCurrentCUDAStream()>>>(
self.numel(),
self.data_ptr<scalar_t>(),
indices.data_ptr<int64_t>(),
numChannels,
inputHeight,
inputWidth,
oheight,
owidth,
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
if (count != 0) {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
self.scalar_type(), "max_unpooling2d_forward_kernel", ([&] {
max_unpooling2d_forward_kernel<<<
GET_BLOCKS(count),
CUDA_NUM_THREADS,
0,
at::cuda::getCurrentCUDAStream()>>>(
self.numel(),
self.data_ptr<scalar_t>(),
indices.data_ptr<int64_t>(),
numChannels,
inputHeight,
inputWidth,
oheight,
owidth,
output.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}));
}
if (self.ndimension() == 3) {
output.resize_({numChannels, oheight, owidth});
}
@@ -191,7 +197,8 @@ static void max_unpooling3d_shape_check(
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding) {
IntArrayRef padding,
const char *fn_name) {
int64_t oT = output_size[0];
int64_t oH = output_size[1];
int64_t oW = output_size[2];
@@ -215,7 +222,11 @@ static void max_unpooling3d_shape_check(
input.sizes() == indices.sizes(),
"Shape of indices should match shape of input");
TORCH_CHECK(input.numel() > 0, "Input must be non-empty");
for (int64_t i = 1; i < input.ndimension(); ++i) {
TORCH_CHECK(input.size(i) > 0, fn_name,
": Expected input to have non-zero size for non-batch dimensions, but got ",
input.sizes(), " with dimension ", i , " being empty.");
}
TORCH_CHECK(
stride[0] > 0 && stride[1] > 0 && stride[2] > 0,
@@ -268,7 +279,7 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
Tensor& output) {
TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
max_unpooling3d_shape_check(
self_, Tensor(), indices_, output_size, stride, padding);
self_, Tensor(), indices_, output_size, stride, padding, "max_unpooling3d_forward_out_cuda()");
int64_t oT = output_size[0];
int64_t oH = output_size[1];
@@ -318,6 +329,10 @@ Tensor& max_unpooling3d_forward_out_cuda(const Tensor& self_,
indices.size(4)});
}
if (self.numel() == 0) {
return output;
}
int totalZ = inputTime * inputSlices * batchSize;
int offsetZ = 0;
dim3 block(32, 8);
@@ -426,6 +441,9 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_,
grad_input.zero_();
int64_t count = self.numel();
if (count == 0) {
return grad_input;
}
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half,
self.scalar_type(), "max_unpooling2d_backward_kernel", ([&] {
@@ -471,7 +489,7 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
int64_t oW = output_size[2];
max_unpooling3d_shape_check(
self_, grad_output_, indices_, output_size, stride, padding);
self_, grad_output_, indices_, output_size, stride, padding, "max_unpooling3d_backward_out_cuda()");
int batchSize = 0;
int inputSlices = 0;
@@ -521,6 +539,9 @@ at::Tensor& max_unpooling3d_backward_out_cuda(const Tensor& grad_output_,
indices.size(3),
indices.size(4)});
}
if (grad_input.numel() == 0) {
return grad_input;
}
int totalZ = inputTime * inputSlices * batchSize;
int offsetZ = 0;
34 changes: 34 additions & 0 deletions test/test_nn.py
@@ -13763,6 +13763,40 @@ def test_MaxPool_zero_batch_dim(self, device):
inp = torch.ones(1, 0, 50, 44, 31, device=device)
mod(inp)

@onlyOnCPUAndCUDA
def test_MaxUnpool_zero_batch_dim(self, device):
pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
output, indices = pool(inp)
output.requires_grad_(True)
unpool_out = unpool(output, indices)
unpool_out.sum().backward()

self.assertEqual(inp.grad, torch.zeros_like(inp))
self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
output, indices = pool(inp)
unpool_out = unpool(output, indices)
unpool_out.sum().backward()

self.assertEqual(inp.grad, torch.zeros_like(inp))
self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
output, indices = pool(inp)
output.requires_grad_(True)
unpool_out = unpool(output, indices)
unpool_out.sum().backward()

self.assertEqual(inp.grad, torch.zeros_like(inp))
self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

@onlyOnCPUAndCUDA
def test_AdaptiveMaxPool_zero_batch_dim(self, device):
inp = torch.randn(0, 16, 50, device=device)
