Skip to content

Commit

Permalink
[pt2] add meta for _adaptive_avg_pool3d_backward (pytorch#105816)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#105816
Approved by: https://github.com/ezyang
  • Loading branch information
nkaretnikov authored and pytorchmergebot committed Jul 26, 2023
1 parent 36ae359 commit 0c65a2d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 0 additions & 4 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2817,7 +2817,6 @@ def forward(self, x):
xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos...
xfail('median', ''), # could not find kernel
xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d_backward.default - couldn't ...
xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo...
xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2...
skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te..
Expand All @@ -2829,7 +2828,6 @@ def forward(self, x):
xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g...
xfail('nn.functional.grid_sample', ''), # RuntimeError: aten.grid_sampler_3d.default - couldn't find sym ...
xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st...
xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down Expand Up @@ -2966,8 +2964,6 @@ def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op):
torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
torch.nn.AdaptiveAvgPool3d, # could not find kernel for aten._adaptive_avg_pool3d_backward.default at dispatch key
# DispatchKey.Meta
torch.nn.AdaptiveMaxPool2d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.AdaptiveMaxPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group)
Expand Down
7 changes: 7 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,13 @@ def meta__adaptive_avg_pool2d_backward(grad_out, self):
return self.new_empty(self.shape).to(memory_format=memory_format)


@register_meta(aten._adaptive_avg_pool3d_backward)
@out_wrapper()
def meta__adaptive_avg_pool3d_backward(grad_output, self):
_adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)


def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
ndim = grad_output.ndim
for i in range(1, ndim):
Expand Down

0 comments on commit 0c65a2d

Please sign in to comment.