[pt2] add metas for max_unpool2d and max_unpool3d (pytorch#103821)
nkaretnikov authored and pytorchmergebot committed Jul 1, 2023
1 parent f9aa004 commit c4a6f86
Showing 9 changed files with 159 additions and 30 deletions.
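The commit itself carries no usage notes, so here is a minimal sketch (not part of the commit, and assuming a PyTorch build that includes this change) of what the new meta registrations enable: shape-only evaluation of max_unpool2d and max_unpool3d on meta tensors, with no real data and no kernel launch. The tensor shapes and the output_size/stride/padding values below are illustrative assumptions.

import torch

# Indices must be int64 and match the input shape, per the checks in the new meta kernels.
x2d = torch.empty(2, 3, 4, 4, device="meta")
idx2d = torch.zeros(2, 3, 4, 4, dtype=torch.int64, device="meta")

# Dispatches to meta_max_unpool2d; only the output shape is computed.
out2d = torch.ops.aten.max_unpool2d(x2d, idx2d, [8, 8])
print(out2d.shape)  # torch.Size([2, 3, 8, 8])

x3d = torch.empty(2, 3, 2, 4, 4, device="meta")
idx3d = torch.zeros(2, 3, 2, 4, 4, dtype=torch.int64, device="meta")

# Dispatches to meta_max_unpool3d; stride and padding are validated, and the
# result shape comes straight from output_size.
out3d = torch.ops.aten.max_unpool3d(x3d, idx3d, [4, 8, 8], [2, 2, 2], [0, 0, 0])
print(out3d.shape)  # torch.Size([2, 3, 4, 8, 8])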
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/MaxUnpooling.cu
@@ -150,7 +150,7 @@ Tensor& max_unpooling2d_forward_out_cuda(const Tensor& self_,
"Expected shape of indices to be: ", self_.sizes(), " but got: ", indices_.sizes());
TORCH_CHECK(
output_size.size() == 2,
"There should be exactly two elements (width, height) in output_size, but got ", output_size.size(), " elements.");
"There should be exactly two elements (height, width) in output_size, but got ", output_size.size(), " elements.");

int64_t dimw = 2;
int64_t dimh = 1;
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -11778,25 +11778,25 @@
CPU: max_pool3d_with_indices_backward_cpu
CUDA: max_pool3d_with_indices_backward_cuda

- func: max_unpool2d.out(Tensor self, Tensor indices, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
- func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
CPU: max_unpooling2d_forward_out_cpu
CUDA: max_unpooling2d_forward_out_cuda

- func: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling2d_forward_cpu
CUDA: max_unpooling2d_forward_cuda

- func: max_unpool3d.out(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_out_cpu
CUDA: max_unpooling3d_forward_out_cuda

- func: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_cpu
6 changes: 0 additions & 6 deletions test/functorch/test_aotdispatch.py
@@ -2836,12 +2836,6 @@ def forward(self, x):
xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
xfail('nn.functional.max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic m...
xfail('nn.functional.max_unpool1d', ''), # aten.max_unpool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta ...
xfail('nn.functional.max_unpool2d', ''), # aten.max_unpool2d.default - couldn't find symbolic meta funct...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta ...
xfail('nn.functional.max_unpool3d', ''), # aten.max_unpool3d.default - couldn't find symbolic meta funct...
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta ...
xfail('nn.functional.multi_margin_loss', ''), # could not find kernel
xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
10 changes: 0 additions & 10 deletions test/test_meta.py
@@ -633,9 +633,6 @@ def run_meta_crossref(
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
torch.nn.functional.max_pool3d : {f64, f32},
torch.nn.functional.max_pool3d_with_indices : {f64, f32},
torch.nn.functional.max_unpool1d : {f64, f32},
torch.nn.functional.max_unpool2d : {f64, f32},
torch.nn.functional.max_unpool3d : {f64, f32},
torch.nn.functional.multi_margin_loss : {f64, f32},
torch.nn.functional.multilabel_margin_loss : {f64, f32},
torch.nn.functional.one_hot : {i64},
@@ -727,9 +724,6 @@ def run_meta_crossref(
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices
torch.nn.functional.max_pool3d_with_indices: {bf16, f16}, # aten::max_pool3d_with_indices
torch.nn.functional.max_unpool1d: {f16}, # aten::max_unpool2d
torch.nn.functional.max_unpool2d: {f16}, # aten::max_unpool2d
torch.nn.functional.max_unpool3d: {f16}, # aten::max_unpool3d
torch.nn.functional.multi_margin_loss: {bf16, f16}, # aten::multi_margin_loss
torch.nn.functional.multilabel_margin_loss: {bf16, f16}, # aten::multilabel_margin_loss_forward
torch.ormqr: {f32, f64}, # aten::ormqr, aten::ormqr.out
@@ -854,8 +848,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.histogram.bins_tensor : {f32, f64},
aten.kthvalue.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.max_pool3d_with_indices.default : {f32, f64},
aten.max_unpool2d.default : {f32, f64},
aten.max_unpool3d.default : {f32, f64},
aten.median.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.median.dim : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.mode.default : {f16, i8, f64, i64, bf16, f32, i32, b8, i16, u8},
@@ -916,8 +908,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
aten.log_sigmoid_forward.output : {bf16, f16, f64, f32}, # aten::log_sigmoid_forward.output
aten.max_pool3d_with_indices.default: {bf16, f16}, # aten::max_pool3d_with_indices
aten.max_unpool2d.default: {f16}, # aten::max_unpool2d
aten.max_unpool3d.default: {f16}, # aten::max_unpool3d
aten.median.default: {f16}, # aten::median
aten.median.dim: {f16}, # aten::median.dim_values
aten.multi_margin_loss.default: {bf16, f16}, # aten::multi_margin_loss
3 changes: 0 additions & 3 deletions test/test_proxy_tensor.py
@@ -1543,9 +1543,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes.
xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -2257,12 +2257,12 @@
result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
output_differentiability: [True, False]

- name: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor
- name: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
self: max_pool_double_backward(grad, indices, 2)
indices: non_differentiable
result: auto_linear

- name: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor
- name: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
self: max_pool_double_backward(grad, indices, 3)
indices: non_differentiable
result: auto_linear
128 changes: 128 additions & 0 deletions torch/_meta_registrations.py
@@ -3456,6 +3456,134 @@ def meta_max_pool2d_with_indices(
)


@register_meta(aten.max_unpool2d)
@out_wrapper()
def meta_max_unpool2d(self_, indices, output_size):
utils.alert_not_deterministic("max_unpooling2d_forward_out")

torch._check(
indices.dtype == torch.int64,
lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
)
torch._check(
len(output_size) == 2,
lambda: (
f"There should be exactly two elements (height, width) in output_size, "
f"but got {len(output_size)} elements."
),
)

oheight, owidth = output_size

torch._check(
self_.ndim in (3, 4),
lambda: (
f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
f"but got a tensor with {self_.ndim} dimensions."
),
)
torch._check(
self_.shape == indices.shape,
lambda: (
f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
f"but got indices tensor with shape: {indices.shape}"
),
)

for i in range(1, self_.ndim):
torch._check(
self_.size(i) > 0,
lambda: (
f"max_unpooling2d(): "
f"Expected input to have non-zero size for non-batch dimensions, "
f"but got {self_.shape} with dimension {i} being empty."
),
)

self = self_.contiguous()

if self_.ndim == 3:
nchannels = self.size(0)
result = self.new_empty((nchannels, oheight, owidth))
else:
nbatch = self.size(0)
nchannels = self.size(1)
result = self.new_empty((nbatch, nchannels, oheight, owidth))

return result


def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
torch._check(
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
)
torch._check(
input.ndim in (4, 5),
lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
)
torch._check(
len(output_size) == 3,
lambda: (
f"There should be exactly three elements (depth, height, width) in output_size, "
f"but got {len(output_size)} elements."
),
)
torch._check(
len(stride) == 3,
lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
)
torch._check(
len(padding) == 3,
lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
)
torch._check(
input.shape == indices.shape,
lambda: (
f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
f"but got indices tensor with shape: {indices.shape}"
),
)

for i in range(1, input.ndim):
torch._check(
input.size(i) > 0,
lambda: (
f"{fn_name}: "
f"Expected input to have non-zero size for non-batch dimensions, "
f"but got {input.shape} with dimension {i} being empty."
),
)

torch._check(
stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
lambda: f"strides should be greater than zero, but got stride: {stride}",
)


@register_meta(aten.max_unpool3d)
@out_wrapper()
def meta_max_unpool3d(self_, indices, output_size, stride, padding):
utils.alert_not_deterministic("max_unpooling3d_forward_out")

_max_unpooling3d_shape_check(
self_, indices, output_size, stride, padding, "max_unpooling3d()"
)

self = self_.contiguous()

odepth, oheight, owidth = output_size

if self_.ndim == 4:
nchannels = self.size(0)
result = self.new_empty((nchannels, odepth, oheight, owidth))
else:
nbatch = self.size(0)
nchannels = self.size(1)
result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))

return result


@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
grad_output,
20 changes: 20 additions & 0 deletions torch/_prims_common/__init__.py
@@ -1803,6 +1803,26 @@ def clone_preserve_strides(x):
finally:
torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, old)


def alert_not_deterministic(caller: str):
if torch.are_deterministic_algorithms_enabled():
if torch.is_deterministic_algorithms_warn_only_enabled():
warnings.warn(
f"{caller} does not have a deterministic implementation, but you set "
f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
f"You can file an issue at https://github.com/pytorch/pytorch/issues "
f"to help us prioritize adding deterministic support for this operation.")
else:
torch._check(
False,
lambda: (f"{caller} does not have a deterministic implementation, but you set "
f"'torch.use_deterministic_algorithms(True)'. You can turn off "
f"determinism just for this operation, or you can use the "
f"'warn_only=True' option, if that's acceptable for your application. "
f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
f"to help us prioritize adding deterministic support for this operation."))


class CUDARngStateHelper:
@staticmethod
def get_torch_state_as_tuple(fake_mode=nullcontext()):
8 changes: 4 additions & 4 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -2093,14 +2093,14 @@ Tensor max_pool_double_backward(
AT_ASSERT(indices.dim() >= dim);
// handle non-empty inputs
if (indices.sym_numel() != 0) {
auto size = indices.sizes().slice(0, indices.dim() - dim).vec();
auto size = indices.sym_sizes().slice(0, indices.dim() - dim).vec();
size.push_back(-1);
auto indices_view = indices.view(size);
auto indices_view = indices.view_symint(size);
const auto memory_format = indices.suggest_memory_format();
return grad.contiguous(memory_format)
.view(size)
.view_symint(size)
.gather(-1, indices_view)
.view(indices.sizes());
.view_symint(indices.sym_sizes());
}
// handle empty inputs
else {
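A closing note, not part of the diff: the alert_not_deterministic helper added in torch/_prims_common/__init__.py makes the new meta kernels honor the determinism flags. Below is a small sketch of the observable behavior; the shapes are illustrative, and it again assumes a build that includes this change.

import torch

x = torch.empty(1, 1, 4, 4, device="meta")
idx = torch.zeros(1, 1, 4, 4, dtype=torch.int64, device="meta")

# warn_only=True: the helper emits a UserWarning and the meta kernel still returns a shape.
torch.use_deterministic_algorithms(True, warn_only=True)
out = torch.ops.aten.max_unpool2d(x, idx, [8, 8])
print(out.shape)  # torch.Size([1, 1, 8, 8])

# Strict mode: torch._check(False, ...) inside the helper raises a RuntimeError.
torch.use_deterministic_algorithms(True)
try:
    torch.ops.aten.max_unpool2d(x, idx, [8, 8])
except RuntimeError as err:
    print("refused:", err)

torch.use_deterministic_algorithms(False)  # restore the default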
